From a636900b9c74915dafe69f2fbac376217c361bb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elias=20J=C3=A4gerskogh?= Date: Fri, 25 Jul 2025 17:20:48 +0200 Subject: [PATCH 1/6] Added possibility to reverse the flow. --- README.md | 2 ++ scripts/reverse.jl | 79 ++++++++++++++++++++++++++++++++++++++++++++++ src/ChainStorm.jl | 23 ++++++++++++++ src/flow.jl | 78 +++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 176 insertions(+), 6 deletions(-) create mode 100644 scripts/reverse.jl diff --git a/README.md b/README.md index b7680c9..246c9fc 100644 --- a/README.md +++ b/README.md @@ -112,3 +112,5 @@ for epoch in 1:100 jldsave("model_epoch_$epoch.jld", model_state = Flux.state(cpu(model)), opt_state=cpu(opt_state)) end ``` + +The code written __will__ break the other examples. diff --git a/scripts/reverse.jl b/scripts/reverse.jl new file mode 100644 index 0000000..51c7bd6 --- /dev/null +++ b/scripts/reverse.jl @@ -0,0 +1,79 @@ +using Pkg +Pkg.add(["GLMakie", "ProtPlot", "ProgressBars"]) + +using ChainStorm, Flowfusion, GLMakie, ProtPlot, ProgressBars + +@eval Flowfusion begin + function reverse_gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector, record; tracker::Function=Returns(nothing), midpoint = false, progress_bar=identity) + Xₜ = copy.(X₀) + push!(record, (1, X₀, nothing)) + for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end])) + T = eltype(s₁) + s2 = s₂ + s1 = s₁ == s₂ ? s2 - T(0.001) : s₁ + t = midpoint ? 
(s1 + s2) / 2 : t = s1 + X0hat, X1hat = model(t, Xₜ) + X0hat = resolveprediction(X0hat, Xₜ) + X1hat = resolveprediction(X1hat, Xₜ) + Xₜ = mask(step(P, Xₜ, X0hat, s1, s2), X₀) + + push!(record, (1-s₂, Xₜ, X1hat)) #records all the steps + tracker(1-t, Xₜ, X1hat) + end + return Xₜ + end + function gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector; tracker::Function=Returns(nothing), midpoint = false, progress_bar=identity) + Xₜ = copy.(X₀) + for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end])) + t = midpoint ? (s₁ + s₂) / 2 : t = s₁ + hat = resolveprediction(model(t, Xₜ), Xₜ) + Xₜ = mask(step(P, Xₜ, hat, s₁, s₂), X₀) + tracker(t, Xₜ, hat) + end + return Xₜ + end + + function bind_gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector, record; tracker::Function=Returns(nothing), midpoint = false, progress_bar = identity) + Xₜ = copy.(X₀) + for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end])) + t = midpoint ? 
(s₁ + s₂) / 2 : t = s₁ + + hat = resolveprediction(model(t, Xₜ), Xₜ) + + #Changes xt + s₁, old_xt, _ = record[end] + pop!(record) + tensor(Xₜ[1])[:, :, 1:size(tensor(old_xt[1]), 3), :] .= tensor(old_xt[1]) + tensor(Xₜ[2])[:, :, 1:size(tensor(old_xt[2]), 3), :] .= tensor(old_xt[2]) + tensor(Xₜ[3]).indices[1:size(tensor(old_xt[3]).indices, 1), :] .= tensor(old_xt[3]).indices + + + Xₜ = mask(step(P, Xₜ, hat, s₁, s₂), X₀) + + if length(record) == 1 + tensor(Xₜ[1])[:, :, 1:size(tensor(old_xt[1]), 3), :] .= tensor(old_xt[1]) + tensor(Xₜ[2])[:, :, 1:size(tensor(old_xt[2]), 3), :] .= tensor(old_xt[2]) + tensor(Xₜ[3]).indices[1:size(tensor(old_xt[3]).indices, 1), :] .= tensor(old_xt[3]).indices + end + tracker(t, Xₜ, hat) + end + return Xₜ + end + + export bind_gen, reverse_gen +end + + +model = load_model(); + +b = ChainStorm.pdb2batch("path/to/pdb") + +g, recorded = flow_quickgen(ChainStorm.reverse_process(ChainStorm.P), b, ChainStorm.compound_state(b), model, is_reverse = true, smooth=0, progress_bar = ProgressBar) #<- Model inference call + +b = dummy_batch([ChainStorm.lengths_from_chainids(b.chainids); [10]]) +paths = ChainStorm.Tracker() +g = flow_quickgen(ChainStorm.P, b, ChainStorm.zero_state(b), model, tracker = paths, smooth=0, record = recorded, progress_bar = ProgressBar) +id = join(string.(ChainStorm.lengths_from_chainids(b.chainids)),"_")*"-"*join(rand('A':'Z', 4)) +export_pdb("$(id)_bind.pdb", g, b.chainids, b.resinds) #<- Save PDB +samp = gen2prot(g, b.chainids, b.resinds) +animate_trajectory("$(id)_bind.mp4", samp, first_trajectory(paths), viewmode = :fit) \ No newline at end of file diff --git a/src/ChainStorm.jl b/src/ChainStorm.jl index 76c8e6f..3e958aa 100644 --- a/src/ChainStorm.jl +++ b/src/ChainStorm.jl @@ -10,6 +10,29 @@ function load_model(; checkpoint = "ChainStormV1.jld2") return Flux.loadmodel!(ChainStormV1(), JLD2.load(file, "model_state")) end +function pdb2batch(file) + struc = read(file, ProteinStructure) + struc.cluster = 1 + 
DLProteinFormats.batch_flatrecs([DLProteinFormats.flatten(struc),]) +end + +function lengths_from_chainids(chainids) + counts = Int[] # Initialize an empty array to store counts + current_count = 1 + + for i in 2:length(chainids) + if chainids[i] == chainids[i - 1] + current_count += 1 + else + push!(counts, current_count) + current_count = 1 + end + end + + push!(counts, current_count) + return counts +end + chainids_from_lengths(lengths) = vcat([repeat([i],l) for (i,l) in enumerate(lengths)]...) function gen2prot(samp, chainids, resnums; name = "Gen", ) d = Dict(zip(0:25,'A':'Z')) diff --git a/src/flow.jl b/src/flow.jl index 4c4a315..7238b32 100644 --- a/src/flow.jl +++ b/src/flow.jl @@ -3,6 +3,16 @@ const rotM = Flowfusion.Rotations(3) schedule_f(t) = 1-(1-t)^2 const P = (FProcess(BrownianMotion(0.2f0), schedule_f), FProcess(ManifoldProcess(0.2f0), schedule_f), NoisyInterpolatingDiscreteFlow(0.2f0, K = 2, dummy_token = 21)) +function reverse_process(P) + continuous_schedule = t -> 1 - P[1].F(1 - t) + manifold_schedule = t -> 1 - P[2].F(1 - t) + κ₁ = t -> 1 - P[3].κ₁(1 - t) + dκ₁ = t -> P[3].dκ₁(1 - t) + κ₂ = t -> 1 - P[3].κ₂(1 - t) + dκ₂ = t -> P[3].dκ₁(1 - t) + (Flowfusion.FProcess(P[1].P, continuous_schedule), Flowfusion.FProcess(P[2].P, manifold_schedule), Flowfusion.NoisyInterpolatingDiscreteFlow(κ₁, dκ₁, κ₂, dκ₂, P[3].mask_token)) +end + function compound_state(b) L,B = size(b.aas) cmask = b.aas .< 100 @@ -45,7 +55,6 @@ function flowX1predictor(X0, b, model; d = identity, smooth = 0) prev_trans = values(translation(f)) T = eltype(prev_trans) function m(t, Xt) - print(".") f, aalogits = model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames = f) values(translation(f)) .= prev_trans .* T(smooth) .+ values(translation(f)) .* T(1-smooth) prev_trans = values(translation(f)) @@ -54,17 +63,74 @@ function flowX1predictor(X0, b, model; d = identity, smooth = 0) return m end +function flowX0predictor(X0, b, model, P; d = identity, 
smooth = 0) # Forces P to be a FProcess and doesn't work for some reason for P Deterministic + batch_dim = size(tensor(X0[1]), 4) + ff, _ = model(d(ones(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds)) # ones makes it start at time = 1 + if P[1].P isa Deterministic + v = 0 + else + v = P[1].P.v + end + function m(rt, Xt) + ff, aalogits = model(d(1-rt .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames=ff) + X1Hat = deepcopy(ff) + t = 1f0 .- P[1].F.(rt .+ zeros(Float32, 1, batch_dim)) + t[t .>= 0.999] .= 0.999 + values(translation(X1Hat)) .= (tensor(Xt[1]) .- values(translation(X1Hat)) .* t) ./ (1 .- t + v .* t) + M = Xt[2].S.M + p = eachslice(tensor(Xt[2]), dims=(3, 4)) + tangent = -t ./ (1 .- t) .* log.((M,), p, eachslice(values(linear(X1Hat)), dims=(3, 4))) + X0Hat = exp.((M,), p, tangent) + values(linear(X1Hat)) .= stack(X0Hat) + T = eltype(aalogits) + aalogits .= T(-Inf) + aalogits[21,:,:] .= 0 + return (cpu(values(translation(X1Hat))), ManifoldState(rotM, eachslice(cpu(values(linear(X1Hat))), dims=(3,4))), cpu(softmax(aalogits))), (cpu(values(translation(ff))), ManifoldState(rotM, eachslice(cpu(values(linear(ff))), dims=(3,4))), cpu(softmax(aalogits))) + end + return m +end + +function bind_flowX1predictor(X0, b, model, recorded; d = identity, smooth = 0) + batch_dim = size(tensor(X0[1]), 4) + + f, _ = model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds)) + values(translation(f))[:, :, 1:size(tensor(recorded[end][3][1]), 3), :] .= tensor(recorded[end][3][1]) # Might be more sensible to do a weighted average of X̂₁ and (1-t)*(Xₜ₊Δₜ - Xₜ)/Δt + values(linear(f))[:, :, 1:size(tensor(recorded[end][3][2]), 3), :] .= tensor(recorded[end][3][2]) + + f, _ = model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds), sc_frames=f) + values(translation(f))[:, :, 1:size(tensor(recorded[end][3][1]), 3), :] .= tensor(recorded[end][3][1]) + values(linear(f))[:, :, 1:size(tensor(recorded[end][3][2]), 
3), :] .= tensor(recorded[end][3][2]) + function m(t, Xt) + f, aalogits = model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames = f) + values(translation(f))[:, :, 1:size(tensor(recorded[end][3][1]), 3), :] .= tensor(recorded[end][3][1]) + values(linear(f))[:, :, 1:size(tensor(recorded[end][3][2]), 3), :] .= tensor(recorded[end][3][2]) + + return cpu(values(translation(f))), ManifoldState(rotM, eachslice(cpu(values(linear(f))), dims=(3,4))), cpu(softmax(aalogits)) + end + return m +end + H(a; d = 2/3) = a<=d ? (a^2)/2 : d*(a - d/2) S(a) = H(a)/H(1) -function flow_quickgen(b, model; steps = :default, d = identity, tracker = Returns(nothing), smooth = 0.6) +function flow_quickgen(P, b, X0, model; is_reverse = false, steps = :default, d = identity, tracker = Returns(nothing), smooth = 0.6, record = [], progress_bar=identity) stps = vcat(zeros(5),S.([0.0:0.00255:0.9975;]),[0.999, 0.9998, 1.0]) - if steps isa Number + + if steps isa Number stps = 0f0:1f0/steps:1f0 elseif steps isa AbstractVector stps = steps end - X0 = zero_state(b) - X1pred = flowX1predictor(X0, b, model, d = d, smooth = smooth) - return gen(P, X0, X1pred, Float32.(stps), tracker = tracker) + b.locs .= tensor(X0[1]) + b.aas .= convert(Matrix{Int64}, tensor(X0[3]).indices) + if !is_reverse && length(record) == 0 + X1pred = flowX1predictor(X0, b, model, d = d, smooth = smooth) + return gen(P, X0, X1pred, Float32.(stps), tracker = tracker, progress_bar = progress_bar) + elseif is_reverse + X0pred = flowX0predictor(X0, b, model, P, d = d, smooth = smooth) + return reverse_gen(P, X0, X0pred, Float32.(1 .- reverse(stps)), record, tracker = tracker, progress_bar = progress_bar), record + else + X1pred = bind_flowX1predictor(X0, b, model, record, d = d, smooth = smooth) + return bind_gen(P, X0, X1pred, Float32.(stps), record, tracker = tracker, progress_bar = progress_bar) + end end From a7805fe6078353e4bf131255a7d3e006caa44f81 Mon Sep 17 00:00:00 2001 From: murrellb Date: 
Sat, 2 Aug 2025 15:29:12 +0000 Subject: [PATCH 2/6] Fixing for GPU, and adding some features --- src/ChainStorm.jl | 5 +- src/extras/reversal.jl | 121 +++++++++++++++++++++++++++++++++++++++++ src/flow.jl | 62 +++++++++++++++------ 3 files changed, 167 insertions(+), 21 deletions(-) create mode 100644 src/extras/reversal.jl diff --git a/src/ChainStorm.jl b/src/ChainStorm.jl index 3e958aa..5ac6c53 100644 --- a/src/ChainStorm.jl +++ b/src/ChainStorm.jl @@ -10,10 +10,9 @@ function load_model(; checkpoint = "ChainStormV1.jld2") return Flux.loadmodel!(ChainStormV1(), JLD2.load(file, "model_state")) end -function pdb2batch(file) - struc = read(file, ProteinStructure) +function pdb2batch(struc::ProteinChains.ProteinStructure) struc.cluster = 1 - DLProteinFormats.batch_flatrecs([DLProteinFormats.flatten(struc),]) + return DLProteinFormats.batch_flatrecs([DLProteinFormats.flatten(struc),]) end function lengths_from_chainids(chainids) diff --git a/src/extras/reversal.jl b/src/extras/reversal.jl new file mode 100644 index 0000000..6edf102 --- /dev/null +++ b/src/extras/reversal.jl @@ -0,0 +1,121 @@ +using Pkg +Pkg.activate(".") +using Revise +Pkg.develop(path="../") + +using Pkg +Pkg.add(["GLMakie", "ProtPlot", "ProgressBars", "CUDA", "cuDNN", "Flux"]) + +using ChainStorm, Flowfusion, GLMakie, ProtPlot, ProgressBars, CUDA, Flux + +@eval Flowfusion begin + function reverse_gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector, record; tracker::Function=Returns(nothing), midpoint = false, progress_bar=identity, snap_time = 0) + Xₜ = copy.(X₀) + push!(record, (1, X₀, nothing)) + for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end])) + T = eltype(s₁) + s2 = s₂ + s1 = s₁ == s₂ ? s2 - T(0.001) : s₁ + t = midpoint ? 
(s1 + s2) / 2 : t = s1 + X0hat, X1hat = model(t, Xₜ) + X0hat = resolveprediction(X0hat, Xₜ) + X1hat = resolveprediction(X1hat, Xₜ) + Xₜ = mask(step(P, Xₜ, X0hat, s1, s2), X₀) + if t < snap_time + fakeX1hat = deepcopy(X1hat) + tensor(fakeX1hat[1]) .= tensor(X₀[1]) + tensor(fakeX1hat[2]) .= tensor(X₀[2]) + push!(record, (1-s₂, deepcopy(Xₜ), fakeX1hat)) #records all the steps + else + push!(record, (1-s₂, deepcopy(Xₜ), deepcopy(X1hat))) #records all the steps + end + tracker(1-t, Xₜ, X1hat) + end + return Xₜ + end + function gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector; tracker::Function=Returns(nothing), midpoint = false, progress_bar=identity) + Xₜ = copy.(X₀) + for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end])) + t = midpoint ? (s₁ + s₂) / 2 : t = s₁ + hat = resolveprediction(model(t, Xₜ), Xₜ) + Xₜ = mask(step(P, Xₜ, hat, s₁, s₂), X₀) + tracker(t, Xₜ, hat) + end + return Xₜ + end + function bind_gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector, record; tracker::Function=Returns(nothing), midpoint = false, progress_bar = identity) + Xₜ = copy.(X₀) + for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end])) + t = midpoint ? 
(s₁ + s₂) / 2 : t = s₁ + hat = resolveprediction(model(t, Xₜ), Xₜ) + #Changes xt + s₁, old_xt, _ = record[end] + pop!(record) + tensor(Xₜ[1])[:, :, 1:size(tensor(old_xt[1]), 3), :] .= tensor(old_xt[1]) + tensor(Xₜ[2])[:, :, 1:size(tensor(old_xt[2]), 3), :] .= tensor(old_xt[2]) + tensor(Xₜ[3]).indices[1:size(tensor(old_xt[3]).indices, 1), :] .= tensor(old_xt[3]).indices + Xₜ = mask(step(P, Xₜ, hat, s₁, s₂), X₀) + if length(record) == 1 + tensor(Xₜ[1])[:, :, 1:size(tensor(old_xt[1]), 3), :] .= tensor(old_xt[1]) + tensor(Xₜ[2])[:, :, 1:size(tensor(old_xt[2]), 3), :] .= tensor(old_xt[2]) + tensor(Xₜ[3]).indices[1:size(tensor(old_xt[3]).indices, 1), :] .= tensor(old_xt[3]).indices + end + tracker(t, Xₜ, hat) + end + return Xₜ + end + export bind_gen, reverse_gen +end + + +model = load_model() |> gpu + +struc = pdb"7RBY"1 +target = ChainStorm.pdb2batch(struc[[1]]) + + +@time rev_g, recorded = flow_quickgen(ChainStorm.reverse_process(ChainStorm.P), target, ChainStorm.compound_state(target), model, is_reverse = true, smooth=0, d = gpu, steps = 0f0:0.005f0:1f0, progress_bar = ProgressBar, snap_time = 0.9f0); + +#proportions = [mean(unhot(recorded[i][2][3]).S.state .== 21) for i in 1:length(recorded)] +#pl = Plots.plot(proportions, xlabel = "t step", label = :none, ylabel = "P(21)") +#savefig(pl, "proportions.pdf") + +b = dummy_batch([ChainStorm.lengths_from_chainids(target.chainids); [122, 114]]) +b.resinds[1:length(target.resinds)] .= target.resinds +binder_inds = length(target.resinds)+1:length(b.resinds) +b.resinds[binder_inds] .= 1:length(binder_inds) +X0 = ChainStorm.zero_state(b) +#If you want to bias the starting location: +#Flowfusion.tensor(X0[1])[:,1,binder_inds,1] .*= 0.5f0 +#Flowfusion.tensor(X0[1])[:,1,binder_inds,1] .+= [0.0f0, 0.1f0, 0.1f0] + +paths = ChainStorm.Tracker() +@time fwd_g = flow_quickgen(ChainStorm.P, b, X0, model, tracker = paths, smooth=0, record = deepcopy(recorded), progress_bar = ProgressBar, d = gpu, steps = 0f0:0.005f0:1f0); +id = 
join(string.(ChainStorm.lengths_from_chainids(b.chainids)),"_")*"-"*join(rand('A':'Z', 4)) +export_pdb("samples/$(id)_bind.pdb", fwd_g, b.chainids, b.resinds) #<- Save PDB + +samp = gen2prot(fwd_g, b.chainids, b.resinds) +animate_trajectory("samples/$(id)_bind.mp4", samp, first_trajectory(paths), viewmode = :fit) + +for _ in 1:10 + snap = rand([0f0, 0.9f0]) + @time rev_g, recorded = flow_quickgen(ChainStorm.reverse_process(ChainStorm.P), target, ChainStorm.compound_state(target), model, is_reverse = true, smooth=0, d = gpu, steps = 0f0:0.005f0:1f0, progress_bar = ProgressBar, snap_time = snap); #, steps = 0f0:0.025f0:1f0); #<- Model inference call + lens = [200+rand(1:25),200+rand(1:25)] + if rand() < 0.33 + lens = [100+rand(1:25),100+rand(1:25)] + end + if rand() < 0.33 + lens = [rand(30:150)] + end + b = dummy_batch([ChainStorm.lengths_from_chainids(target.chainids); lens]) + b.resinds[1:length(target.resinds)] .= target.resinds + binder_inds = length(target.resinds)+1:length(b.resinds) + b.resinds[binder_inds] .= 1:length(binder_inds) + X0 = ChainStorm.zero_state(b) + paths = ChainStorm.Tracker() + @time fwd_g = flow_quickgen(ChainStorm.P, b, X0, model, tracker = paths, smooth=0, record = deepcopy(recorded), progress_bar = ProgressBar, d = gpu, steps = 0f0:0.005f0:1f0); + id = join(string.(ChainStorm.lengths_from_chainids(b.chainids)),"_")*"-"*join(rand('A':'Z', 4)) + export_pdb("samples/$(id)_$(snap)_bind.pdb", fwd_g, b.chainids, b.resinds) #<- Save PDB + samp = gen2prot(fwd_g, b.chainids, b.resinds) + animate_trajectory("samples/$(id)_$(snap)_bind.mp4", samp, first_trajectory(paths), viewmode = :fit) +end \ No newline at end of file diff --git a/src/flow.jl b/src/flow.jl index 7238b32..6ac414b 100644 --- a/src/flow.jl +++ b/src/flow.jl @@ -3,6 +3,20 @@ const rotM = Flowfusion.Rotations(3) schedule_f(t) = 1-(1-t)^2 const P = (FProcess(BrownianMotion(0.2f0), schedule_f), FProcess(ManifoldProcess(0.2f0), schedule_f), NoisyInterpolatingDiscreteFlow(0.2f0, K = 2, 
dummy_token = 21)) +#Bringing Alexander's version in - this should be replaced by the "full" solution: +function rev_NoisyInterpolatingDiscreteFlow(noise; K = 1, dummy_token::T = nothing) where T + if (K > 1 && isnothing(dummy_token)) + @warn "NoisyInterpolatingDiscreteFlow: If K>1 things might break if your X0 is not the `dummy_token` (which should also be passed to NoisyInterpolatingDiscreteFlow)." + end + return NoisyInterpolatingDiscreteFlow{T}( + t -> oftype(t,1-(1-cos((π/2)*(1-t)))^K), #K1 + t -> oftype(t,(noise * sin(π*t))), #K2 + t -> oftype(t,(K * (π/2) * cos((π/2) * t) * (1 - sin((π/2) * t))^(K - 1))), #dK1 + t -> oftype(t,(noise*π*cos(π*t))), #dK2 + dummy_token + ) +end + function reverse_process(P) continuous_schedule = t -> 1 - P[1].F(1 - t) manifold_schedule = t -> 1 - P[2].F(1 - t) @@ -10,7 +24,9 @@ function reverse_process(P) dκ₁ = t -> P[3].dκ₁(1 - t) κ₂ = t -> 1 - P[3].κ₂(1 - t) dκ₂ = t -> P[3].dκ₁(1 - t) - (Flowfusion.FProcess(P[1].P, continuous_schedule), Flowfusion.FProcess(P[2].P, manifold_schedule), Flowfusion.NoisyInterpolatingDiscreteFlow(κ₁, dκ₁, κ₂, dκ₂, P[3].mask_token)) + #This is a hack fix: + #(Flowfusion.FProcess(P[1].P, continuous_schedule), Flowfusion.FProcess(P[2].P, manifold_schedule), Flowfusion.NoisyInterpolatingDiscreteFlow(κ₁, dκ₁, κ₂, dκ₂, P[3].mask_token)) + (Flowfusion.FProcess(P[1].P, continuous_schedule), Flowfusion.FProcess(P[2].P, manifold_schedule), rev_NoisyInterpolatingDiscreteFlow(0.2f0, K = 2, dummy_token = 21)) end function compound_state(b) @@ -73,7 +89,8 @@ function flowX0predictor(X0, b, model, P; d = identity, smooth = 0) # Forces P t end function m(rt, Xt) ff, aalogits = model(d(1-rt .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames=ff) - X1Hat = deepcopy(ff) + aalogits = deepcopy(cpu(aalogits)) + X1Hat = deepcopy(cpu(ff)) t = 1f0 .- P[1].F.(rt .+ zeros(Float32, 1, batch_dim)) t[t .>= 0.999] .= 0.999 values(translation(X1Hat)) .= (tensor(Xt[1]) .- values(translation(X1Hat)) .* 
t) ./ (1 .- t + v .* t) @@ -90,21 +107,29 @@ function flowX0predictor(X0, b, model, P; d = identity, smooth = 0) # Forces P t return m end -function bind_flowX1predictor(X0, b, model, recorded; d = identity, smooth = 0) +function bind_flowX1predictor(X0, b, model, recorded; d = identity, smooth = 0, meanshift = true) + recdim = size(tensor(recorded[end][3][1]), 3) batch_dim = size(tensor(X0[1]), 4) - - f, _ = model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds)) - values(translation(f))[:, :, 1:size(tensor(recorded[end][3][1]), 3), :] .= tensor(recorded[end][3][1]) # Might be more sensible to do a weighted average of X̂₁ and (1-t)*(Xₜ₊Δₜ - Xₜ)/Δt - values(linear(f))[:, :, 1:size(tensor(recorded[end][3][2]), 3), :] .= tensor(recorded[end][3][2]) - - f, _ = model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds), sc_frames=f) - values(translation(f))[:, :, 1:size(tensor(recorded[end][3][1]), 3), :] .= tensor(recorded[end][3][1]) - values(linear(f))[:, :, 1:size(tensor(recorded[end][3][2]), 3), :] .= tensor(recorded[end][3][2]) + f, _ = cpu(model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds))) + values(translation(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][1]) # Might be more sensible to do a weighted average of X̂₁ and (1-t)*(Xₜ₊Δₜ - Xₜ)/Δt + values(linear(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][2]) + f, _ = cpu(model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds), sc_frames=d(f))) + recmean = Flux.mean(values(translation(f))[:, :, 1:recdim, :], dims = 3) + forcemean = Flux.mean(tensor(recorded[end][3][1]), dims = 3) + values(translation(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][1]) + if meanshift + values(translation(f))[:, :, 1:recdim, :] .+= forcemean .- recmean + end + values(linear(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][2]) function m(t, Xt) - f, aalogits = model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames 
= f) - values(translation(f))[:, :, 1:size(tensor(recorded[end][3][1]), 3), :] .= tensor(recorded[end][3][1]) - values(linear(f))[:, :, 1:size(tensor(recorded[end][3][2]), 3), :] .= tensor(recorded[end][3][2]) - + f, aalogits = cpu(model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames = d(f))) + recmean = Flux.mean(values(translation(f))[:, :, 1:recdim, :], dims = 3) + forcemean = Flux.mean(tensor(recorded[end][3][1]), dims = 3) + values(translation(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][1]) + if meanshift #This is to shift the binder over by the amount the target would have shifted, in the other direction: + values(translation(f))[:, :, recdim+1:end, :] .+= forcemean .- recmean + end + values(linear(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][2]) return cpu(values(translation(f))), ManifoldState(rotM, eachslice(cpu(values(linear(f))), dims=(3,4))), cpu(softmax(aalogits)) end return m @@ -113,7 +138,7 @@ end H(a; d = 2/3) = a<=d ? (a^2)/2 : d*(a - d/2) S(a) = H(a)/H(1) -function flow_quickgen(P, b, X0, model; is_reverse = false, steps = :default, d = identity, tracker = Returns(nothing), smooth = 0.6, record = [], progress_bar=identity) +function flow_quickgen(P, b, X0, model; is_reverse = false, steps = :default, d = identity, tracker = Returns(nothing), smooth = 0.6, record = [], progress_bar=identity, snap_time = 0) stps = vcat(zeros(5),S.([0.0:0.00255:0.9975;]),[0.999, 0.9998, 1.0]) if steps isa Number @@ -122,13 +147,14 @@ function flow_quickgen(P, b, X0, model; is_reverse = false, steps = :default, d stps = steps end b.locs .= tensor(X0[1]) - b.aas .= convert(Matrix{Int64}, tensor(X0[3]).indices) + #b.aas .= convert(Matrix{Int64}, tensor(X0[3]).indices) + b.aas .= unhot(X0[3]).S.state if !is_reverse && length(record) == 0 X1pred = flowX1predictor(X0, b, model, d = d, smooth = smooth) return gen(P, X0, X1pred, Float32.(stps), tracker = tracker, progress_bar = progress_bar) elseif is_reverse X0pred = 
flowX0predictor(X0, b, model, P, d = d, smooth = smooth) - return reverse_gen(P, X0, X0pred, Float32.(1 .- reverse(stps)), record, tracker = tracker, progress_bar = progress_bar), record + return reverse_gen(P, X0, X0pred, Float32.(1 .- reverse(stps)), record, tracker = tracker, progress_bar = progress_bar, snap_time = snap_time), record else X1pred = bind_flowX1predictor(X0, b, model, record, d = d, smooth = smooth) return bind_gen(P, X0, X1pred, Float32.(stps), record, tracker = tracker, progress_bar = progress_bar) From 32440ed8fd3835490f2576a879e5fc3bfa19ef83 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Sun, 19 Oct 2025 12:44:09 +0000 Subject: [PATCH 3/6] . --- {src/extras => extras}/ChainStormColab.ipynb | 0 {src/extras => extras}/circular.jl | 0 {src/extras => extras}/overwrite_IPA.jl | 0 extras/reversal.jl | 31 +++++ scripts/reverse.jl | 79 ------------ src/ChainStorm.jl | 13 +- src/extras/reversal.jl | 121 ------------------- src/flow.jl | 38 +++--- src/reverse.jl | 62 ++++++++++ 9 files changed, 128 insertions(+), 216 deletions(-) rename {src/extras => extras}/ChainStormColab.ipynb (100%) rename {src/extras => extras}/circular.jl (100%) rename {src/extras => extras}/overwrite_IPA.jl (100%) create mode 100644 extras/reversal.jl delete mode 100644 scripts/reverse.jl delete mode 100644 src/extras/reversal.jl create mode 100644 src/reverse.jl diff --git a/src/extras/ChainStormColab.ipynb b/extras/ChainStormColab.ipynb similarity index 100% rename from src/extras/ChainStormColab.ipynb rename to extras/ChainStormColab.ipynb diff --git a/src/extras/circular.jl b/extras/circular.jl similarity index 100% rename from src/extras/circular.jl rename to extras/circular.jl diff --git a/src/extras/overwrite_IPA.jl b/extras/overwrite_IPA.jl similarity index 100% rename from src/extras/overwrite_IPA.jl rename to extras/overwrite_IPA.jl diff --git a/extras/reversal.jl b/extras/reversal.jl new file mode 100644 index 0000000..51fe274 --- /dev/null +++ b/extras/reversal.jl 
@@ -0,0 +1,31 @@ +using Pkg +Pkg.activate("reversal", shared=true) +#Pkg.develop(path=".") +#Pkg.add(["CUDA", "cuDNN", "Flux"]) + +using ChainStorm, ChainStorm.Flowfusion, ChainStorm.ProteinChains +using CUDA, Flux + +model = load_model() |> gpu + +struc = pdb"7RBY"1[1:1] +target_length = length(struc[1]) +batch_target = ChainStorm.pdb2batch(struc) + +@time rev_g, recorded = flow_quickgen( + ChainStorm.P, batch_target, ChainStorm.compound_state(batch_target), model; + is_reverse = true, d = gpu, steps = 0f0:0.005f0:1f0, snap_time = 0.9f0); + +new_lengths = [122, 114] +batch = dummy_batch([ChainStorm.lengths_from_chainids(batch_target.chainids); new_lengths]) +batch.resinds .= [batch_target.resinds; 1:new_lengths[1]; 1:new_lengths[2]] +X₀ = ChainStorm.zero_state(batch) + +tracker = ChainStorm.Tracker() +@time fwd_g = flow_quickgen( + ChainStorm.P, batch, X₀, model; + tracker, record = deepcopy(recorded), d = gpu, steps = 0f0:0.005f0:1f0); + +id = join(ChainStorm.lengths_from_chainids(batch.chainids),'_')*"-"*join(rand('A':'Z', 4)) + +export_pdb("$(id)_bind.pdb", fwd_g, batch.chainids, batch.resinds) #<- Save PDB diff --git a/scripts/reverse.jl b/scripts/reverse.jl deleted file mode 100644 index 51c7bd6..0000000 --- a/scripts/reverse.jl +++ /dev/null @@ -1,79 +0,0 @@ -using Pkg -Pkg.add(["GLMakie", "ProtPlot", "ProgressBars"]) - -using ChainStorm, Flowfusion, GLMakie, ProtPlot, ProgressBars - -@eval Flowfusion begin - function reverse_gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector, record; tracker::Function=Returns(nothing), midpoint = false, progress_bar=identity) - Xₜ = copy.(X₀) - push!(record, (1, X₀, nothing)) - for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end])) - T = eltype(s₁) - s2 = s₂ - s1 = s₁ == s₂ ? s2 - T(0.001) : s₁ - t = midpoint ? 
(s1 + s2) / 2 : t = s1 - X0hat, X1hat = model(t, Xₜ) - X0hat = resolveprediction(X0hat, Xₜ) - X1hat = resolveprediction(X1hat, Xₜ) - Xₜ = mask(step(P, Xₜ, X0hat, s1, s2), X₀) - - push!(record, (1-s₂, Xₜ, X1hat)) #records all the steps - tracker(1-t, Xₜ, X1hat) - end - return Xₜ - end - function gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector; tracker::Function=Returns(nothing), midpoint = false, progress_bar=identity) - Xₜ = copy.(X₀) - for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end])) - t = midpoint ? (s₁ + s₂) / 2 : t = s₁ - hat = resolveprediction(model(t, Xₜ), Xₜ) - Xₜ = mask(step(P, Xₜ, hat, s₁, s₂), X₀) - tracker(t, Xₜ, hat) - end - return Xₜ - end - - function bind_gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector, record; tracker::Function=Returns(nothing), midpoint = false, progress_bar = identity) - Xₜ = copy.(X₀) - for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end])) - t = midpoint ? 
(s₁ + s₂) / 2 : t = s₁ - - hat = resolveprediction(model(t, Xₜ), Xₜ) - - #Changes xt - s₁, old_xt, _ = record[end] - pop!(record) - tensor(Xₜ[1])[:, :, 1:size(tensor(old_xt[1]), 3), :] .= tensor(old_xt[1]) - tensor(Xₜ[2])[:, :, 1:size(tensor(old_xt[2]), 3), :] .= tensor(old_xt[2]) - tensor(Xₜ[3]).indices[1:size(tensor(old_xt[3]).indices, 1), :] .= tensor(old_xt[3]).indices - - - Xₜ = mask(step(P, Xₜ, hat, s₁, s₂), X₀) - - if length(record) == 1 - tensor(Xₜ[1])[:, :, 1:size(tensor(old_xt[1]), 3), :] .= tensor(old_xt[1]) - tensor(Xₜ[2])[:, :, 1:size(tensor(old_xt[2]), 3), :] .= tensor(old_xt[2]) - tensor(Xₜ[3]).indices[1:size(tensor(old_xt[3]).indices, 1), :] .= tensor(old_xt[3]).indices - end - tracker(t, Xₜ, hat) - end - return Xₜ - end - - export bind_gen, reverse_gen -end - - -model = load_model(); - -b = ChainStorm.pdb2batch("path/to/pdb") - -g, recorded = flow_quickgen(ChainStorm.reverse_process(ChainStorm.P), b, ChainStorm.compound_state(b), model, is_reverse = true, smooth=0, progress_bar = ProgressBar) #<- Model inference call - -b = dummy_batch([ChainStorm.lengths_from_chainids(b.chainids); [10]]) -paths = ChainStorm.Tracker() -g = flow_quickgen(ChainStorm.P, b, ChainStorm.zero_state(b), model, tracker = paths, smooth=0, record = recorded, progress_bar = ProgressBar) -id = join(string.(ChainStorm.lengths_from_chainids(b.chainids)),"_")*"-"*join(rand('A':'Z', 4)) -export_pdb("$(id)_bind.pdb", g, b.chainids, b.resinds) #<- Save PDB -samp = gen2prot(g, b.chainids, b.resinds) -animate_trajectory("$(id)_bind.mp4", samp, first_trajectory(paths), viewmode = :fit) \ No newline at end of file diff --git a/src/ChainStorm.jl b/src/ChainStorm.jl index 5ac6c53..d679408 100644 --- a/src/ChainStorm.jl +++ b/src/ChainStorm.jl @@ -1,7 +1,18 @@ module ChainStorm -using Flowfusion, ForwardBackward, Flux, RandomFeatureMaps, Onion, InvariantPointAttention, BatchedTransformations, ProteinChains, DLProteinFormats, HuggingFaceApi, JLD2 +using Flowfusion +using ForwardBackward 
+using Flux +using RandomFeatureMaps +using Onion +using InvariantPointAttention +using BatchedTransformations +using ProteinChains +using DLProteinFormats +using HuggingFaceApi +using JLD2 +include("reverse.jl") include("flow.jl") include("model.jl") diff --git a/src/extras/reversal.jl b/src/extras/reversal.jl deleted file mode 100644 index 6edf102..0000000 --- a/src/extras/reversal.jl +++ /dev/null @@ -1,121 +0,0 @@ -using Pkg -Pkg.activate(".") -using Revise -Pkg.develop(path="../") - -using Pkg -Pkg.add(["GLMakie", "ProtPlot", "ProgressBars", "CUDA", "cuDNN", "Flux"]) - -using ChainStorm, Flowfusion, GLMakie, ProtPlot, ProgressBars, CUDA, Flux - -@eval Flowfusion begin - function reverse_gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector, record; tracker::Function=Returns(nothing), midpoint = false, progress_bar=identity, snap_time = 0) - Xₜ = copy.(X₀) - push!(record, (1, X₀, nothing)) - for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end])) - T = eltype(s₁) - s2 = s₂ - s1 = s₁ == s₂ ? s2 - T(0.001) : s₁ - t = midpoint ? (s1 + s2) / 2 : t = s1 - X0hat, X1hat = model(t, Xₜ) - X0hat = resolveprediction(X0hat, Xₜ) - X1hat = resolveprediction(X1hat, Xₜ) - Xₜ = mask(step(P, Xₜ, X0hat, s1, s2), X₀) - if t < snap_time - fakeX1hat = deepcopy(X1hat) - tensor(fakeX1hat[1]) .= tensor(X₀[1]) - tensor(fakeX1hat[2]) .= tensor(X₀[2]) - push!(record, (1-s₂, deepcopy(Xₜ), fakeX1hat)) #records all the steps - else - push!(record, (1-s₂, deepcopy(Xₜ), deepcopy(X1hat))) #records all the steps - end - tracker(1-t, Xₜ, X1hat) - end - return Xₜ - end - function gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector; tracker::Function=Returns(nothing), midpoint = false, progress_bar=identity) - Xₜ = copy.(X₀) - for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end])) - t = midpoint ? 
(s₁ + s₂) / 2 : t = s₁ - hat = resolveprediction(model(t, Xₜ), Xₜ) - Xₜ = mask(step(P, Xₜ, hat, s₁, s₂), X₀) - tracker(t, Xₜ, hat) - end - return Xₜ - end - function bind_gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector, record; tracker::Function=Returns(nothing), midpoint = false, progress_bar = identity) - Xₜ = copy.(X₀) - for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end])) - t = midpoint ? (s₁ + s₂) / 2 : t = s₁ - hat = resolveprediction(model(t, Xₜ), Xₜ) - #Changes xt - s₁, old_xt, _ = record[end] - pop!(record) - tensor(Xₜ[1])[:, :, 1:size(tensor(old_xt[1]), 3), :] .= tensor(old_xt[1]) - tensor(Xₜ[2])[:, :, 1:size(tensor(old_xt[2]), 3), :] .= tensor(old_xt[2]) - tensor(Xₜ[3]).indices[1:size(tensor(old_xt[3]).indices, 1), :] .= tensor(old_xt[3]).indices - Xₜ = mask(step(P, Xₜ, hat, s₁, s₂), X₀) - if length(record) == 1 - tensor(Xₜ[1])[:, :, 1:size(tensor(old_xt[1]), 3), :] .= tensor(old_xt[1]) - tensor(Xₜ[2])[:, :, 1:size(tensor(old_xt[2]), 3), :] .= tensor(old_xt[2]) - tensor(Xₜ[3]).indices[1:size(tensor(old_xt[3]).indices, 1), :] .= tensor(old_xt[3]).indices - end - tracker(t, Xₜ, hat) - end - return Xₜ - end - export bind_gen, reverse_gen -end - - -model = load_model() |> gpu - -struc = pdb"7RBY"1 -target = ChainStorm.pdb2batch(struc[[1]]) - - -@time rev_g, recorded = flow_quickgen(ChainStorm.reverse_process(ChainStorm.P), target, ChainStorm.compound_state(target), model, is_reverse = true, smooth=0, d = gpu, steps = 0f0:0.005f0:1f0, progress_bar = ProgressBar, snap_time = 0.9f0); - -#proportions = [mean(unhot(recorded[i][2][3]).S.state .== 21) for i in 1:length(recorded)] -#pl = Plots.plot(proportions, xlabel = "t step", label = :none, ylabel = "P(21)") -#savefig(pl, "proportions.pdf") - -b = dummy_batch([ChainStorm.lengths_from_chainids(target.chainids); [122, 114]]) -b.resinds[1:length(target.resinds)] .= target.resinds -binder_inds = length(target.resinds)+1:length(b.resinds) -b.resinds[binder_inds] .= 
1:length(binder_inds) -X0 = ChainStorm.zero_state(b) -#If you want to bias the starting location: -#Flowfusion.tensor(X0[1])[:,1,binder_inds,1] .*= 0.5f0 -#Flowfusion.tensor(X0[1])[:,1,binder_inds,1] .+= [0.0f0, 0.1f0, 0.1f0] - -paths = ChainStorm.Tracker() -@time fwd_g = flow_quickgen(ChainStorm.P, b, X0, model, tracker = paths, smooth=0, record = deepcopy(recorded), progress_bar = ProgressBar, d = gpu, steps = 0f0:0.005f0:1f0); -id = join(string.(ChainStorm.lengths_from_chainids(b.chainids)),"_")*"-"*join(rand('A':'Z', 4)) -export_pdb("samples/$(id)_bind.pdb", fwd_g, b.chainids, b.resinds) #<- Save PDB - -samp = gen2prot(fwd_g, b.chainids, b.resinds) -animate_trajectory("samples/$(id)_bind.mp4", samp, first_trajectory(paths), viewmode = :fit) - -for _ in 1:10 - snap = rand([0f0, 0.9f0]) - @time rev_g, recorded = flow_quickgen(ChainStorm.reverse_process(ChainStorm.P), target, ChainStorm.compound_state(target), model, is_reverse = true, smooth=0, d = gpu, steps = 0f0:0.005f0:1f0, progress_bar = ProgressBar, snap_time = snap); #, steps = 0f0:0.025f0:1f0); #<- Model inference call - lens = [200+rand(1:25),200+rand(1:25)] - if rand() < 0.33 - lens = [100+rand(1:25),100+rand(1:25)] - end - if rand() < 0.33 - lens = [rand(30:150)] - end - b = dummy_batch([ChainStorm.lengths_from_chainids(target.chainids); lens]) - b.resinds[1:length(target.resinds)] .= target.resinds - binder_inds = length(target.resinds)+1:length(b.resinds) - b.resinds[binder_inds] .= 1:length(binder_inds) - X0 = ChainStorm.zero_state(b) - paths = ChainStorm.Tracker() - @time fwd_g = flow_quickgen(ChainStorm.P, b, X0, model, tracker = paths, smooth=0, record = deepcopy(recorded), progress_bar = ProgressBar, d = gpu, steps = 0f0:0.005f0:1f0); - id = join(string.(ChainStorm.lengths_from_chainids(b.chainids)),"_")*"-"*join(rand('A':'Z', 4)) - export_pdb("samples/$(id)_$(snap)_bind.pdb", fwd_g, b.chainids, b.resinds) #<- Save PDB - samp = gen2prot(fwd_g, b.chainids, b.resinds) - 
animate_trajectory("samples/$(id)_$(snap)_bind.mp4", samp, first_trajectory(paths), viewmode = :fit) -end \ No newline at end of file diff --git a/src/flow.jl b/src/flow.jl index 6ac414b..09bd478 100644 --- a/src/flow.jl +++ b/src/flow.jl @@ -42,7 +42,7 @@ function zero_state(b) L,B = size(b.aas) cmask = b.aas .< 100 X0locs = MaskedState(ContinuousState(randn(Float32, size(b.locs))), cmask, b.padmask) - X0rots = MaskedState(ManifoldState(rotM, reshape(Array{Float32}.(Flowfusion.rand(rotM, L*B)), L, B)), cmask, b.padmask) + X0rots = MaskedState(ManifoldState(rotM, reshape(Array{Float32}.(rand(rotM, L*B)), L, B)), cmask, b.padmask) X0aas = MaskedState(DiscreteState(21, Flux.onehotbatch(similar(b.aas) .= 21, 1:21)), cmask, b.padmask) return (X0locs, X0rots, X0aas) end @@ -138,25 +138,33 @@ end H(a; d = 2/3) = a<=d ? (a^2)/2 : d*(a - d/2) S(a) = H(a)/H(1) -function flow_quickgen(P, b, X0, model; is_reverse = false, steps = :default, d = identity, tracker = Returns(nothing), smooth = 0.6, record = [], progress_bar=identity, snap_time = 0) - stps = vcat(zeros(5),S.([0.0:0.00255:0.9975;]),[0.999, 0.9998, 1.0]) - - if steps isa Number - stps = 0f0:1f0/steps:1f0 +function flow_quickgen( + P, b, X0, model; + is_reverse = false, + steps = :default, + d = identity, + smooth = 0, + record = [], + kws... 
+) + steps = if steps isa Number + 0f0:1f0/steps:1f0 elseif steps isa AbstractVector - stps = steps + steps + else + vcat(zeros(5),S.([0.0:0.00255:0.9975;]),[0.999, 0.9998, 1.0]) end b.locs .= tensor(X0[1]) #b.aas .= convert(Matrix{Int64}, tensor(X0[3]).indices) b.aas .= unhot(X0[3]).S.state - if !is_reverse && length(record) == 0 - X1pred = flowX1predictor(X0, b, model, d = d, smooth = smooth) - return gen(P, X0, X1pred, Float32.(stps), tracker = tracker, progress_bar = progress_bar) - elseif is_reverse - X0pred = flowX0predictor(X0, b, model, P, d = d, smooth = smooth) - return reverse_gen(P, X0, X0pred, Float32.(1 .- reverse(stps)), record, tracker = tracker, progress_bar = progress_bar, snap_time = snap_time), record + if is_reverse + X0pred = flowX0predictor(X0, b, model, P; d, smooth) + return reverse_gen(reverse_process(P), X0, X0pred, Float32.(1 .- reverse(steps)), record; kws...), record + elseif isempty(record) + X1pred = flowX1predictor(X0, b, model; d, smooth) + return gen(P, X0, X1pred, Float32.(steps); kws...) else - X1pred = bind_flowX1predictor(X0, b, model, record, d = d, smooth = smooth) - return bind_gen(P, X0, X1pred, Float32.(stps), record, tracker = tracker, progress_bar = progress_bar) + X1pred = bind_flowX1predictor(X0, b, model, record; d, smooth) + return bind_gen(P, X0, X1pred, Float32.(steps), record; kws...) end end diff --git a/src/reverse.jl b/src/reverse.jl new file mode 100644 index 0000000..033fd60 --- /dev/null +++ b/src/reverse.jl @@ -0,0 +1,62 @@ +using Flowfusion: + resolveprediction, mask, step, + UProcess, UState + +function reverse_gen( + P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, + model, steps::AbstractVector, record; + tracker = Returns(nothing), + midpoint = false, + snap_time = 0, +) + Xₜ = copy.(X₀) + push!(record, (1, X₀, nothing)) + for (s₁, s₂) in zip(steps, steps[begin+1:end]) + T = eltype(s₁) + s₁ = s₁ == s₂ ? s₁ - T(0.001) : s₁ + t = midpoint ? 
(s₁ + s₂) / 2 : s₁ + X̂₀, X̂₁ = model(t, Xₜ) + X̂₀ = resolveprediction(X̂₀, Xₜ) + X̂₁ = resolveprediction(X̂₁, Xₜ) + Xₜ = mask(step(P, Xₜ, X̂₀, s₁, s₂), X₀) + if t < snap_time + fakeX̂₁ = deepcopy(X̂₁) + tensor(fakeX̂₁[1]) .= tensor(X₀[1]) + tensor(fakeX̂₁[2]) .= tensor(X₀[2]) + push!(record, (1-s₂, deepcopy(Xₜ), fakeX̂₁)) #records all the steps + else + push!(record, (1-s₂, deepcopy(Xₜ), deepcopy(X̂₁))) #records all the steps + end + tracker(1-t, Xₜ, X̂₁) + end + return Xₜ +end + +function bind_gen( + P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, + model, steps::AbstractVector, record; + tracker = Returns(nothing), + midpoint = false, + snap_time = 0, +) + Xₜ = copy.(X₀) + for (s₁, s₂) in zip(steps, steps[begin+1:end]) + t = midpoint ? (s₁ + s₂) / 2 : t = s₁ + X̂₁ = resolveprediction(model(t, Xₜ), Xₜ) + #Changes xt + s₁, old_Xₜ, _ = record[end] + pop!(record) + tensor(Xₜ[1])[:, :, 1:size(tensor(old_Xₜ[1]), 3), :] .= tensor(old_Xₜ[1]) + tensor(Xₜ[2])[:, :, 1:size(tensor(old_Xₜ[2]), 3), :] .= tensor(old_Xₜ[2]) + tensor(Xₜ[3]).indices[1:size(tensor(old_Xₜ[3]).indices, 1), :] .= tensor(old_Xₜ[3]).indices + Xₜ = mask(step(P, Xₜ, X̂₁, s₁, s₂), X₀) + if length(record) == 1 + tensor(Xₜ[1])[:, :, 1:size(tensor(old_Xₜ[1]), 3), :] .= tensor(old_Xₜ[1]) + tensor(Xₜ[2])[:, :, 1:size(tensor(old_Xₜ[2]), 3), :] .= tensor(old_Xₜ[2]) + tensor(Xₜ[3]).indices[1:size(tensor(old_Xₜ[3]).indices, 1), :] .= tensor(old_Xₜ[3]).indices + end + tracker(t, Xₜ, X̂₁) + end + return Xₜ +end + From 732060041800f2192f334654eac50d172c8ef1c4 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Sun, 19 Oct 2025 20:03:41 +0000 Subject: [PATCH 4/6] . 
--- extras/flex-reversal.jl | 44 +++++++++++++++++++++++++++++ extras/reversal.jl | 24 ++++++++-------- src/flow.jl | 62 +++++++++++++++++++++++++++++++++++++---- src/reverse.jl | 32 +++++++++++++-------- 4 files changed, 133 insertions(+), 29 deletions(-) create mode 100644 extras/flex-reversal.jl diff --git a/extras/flex-reversal.jl b/extras/flex-reversal.jl new file mode 100644 index 0000000..c2715a9 --- /dev/null +++ b/extras/flex-reversal.jl @@ -0,0 +1,44 @@ +using Pkg +Pkg.activate("reversal", shared=true) + +using ChainStorm, ChainStorm.Flowfusion, ChainStorm.ProteinChains +using CUDA, Flux, Random + +model = load_model() |> gpu + +struc = pdb"7RBY"1[1:1] +batch_target = ChainStorm.pdb2batch(struc) +X_target = ChainStorm.compound_state(batch_target) + +phase_points_rev = Float32[1.0, 0.0] +@time _, reverse_records, _ = ChainStorm.flex_gen( + ChainStorm.P, batch_target, X_target, model, phase_points_rev; + step_size = -0.005f0, + record_indices = [1], + d = gpu, + snap_time = 0.9f0, +) +recorded = get(reverse_records, 1, nothing) +recorded === nothing && error("Reverse phase did not produce a recording") +recorded = deepcopy(recorded) + +new_lengths = [122, 114] +batch = dummy_batch([ChainStorm.lengths_from_chainids(batch_target.chainids); new_lengths]) +batch.resinds .= [batch_target.resinds; 1:new_lengths[1]; 1:new_lengths[2]] +X₀ = ChainStorm.zero_state(batch) + +tracker = ChainStorm.Tracker() +phase_points_fw = Float32[0.0, 1.0] +initial_record = (_, _) -> deepcopy(recorded) +@time fwd_state, _, _ = ChainStorm.flex_gen( + ChainStorm.P, batch, X₀, model, phase_points_fw; + step_size = 0.005f0, + record_indices = [1], + initial_recorded = initial_record, + tracker = tracker, + d = gpu, +) + +id = join(ChainStorm.lengths_from_chainids(batch.chainids), '_') * "-" * String(rand('A':'Z', 4)) + +export_pdb("$(id)_bind.pdb", fwd_state, batch.chainids, batch.resinds) diff --git a/extras/reversal.jl b/extras/reversal.jl index 51fe274..04bc252 100644 --- 
a/extras/reversal.jl +++ b/extras/reversal.jl @@ -6,26 +6,26 @@ Pkg.activate("reversal", shared=true) using ChainStorm, ChainStorm.Flowfusion, ChainStorm.ProteinChains using CUDA, Flux -model = load_model() |> gpu +model = load_model() |> gpu; -struc = pdb"7RBY"1[1:1] -target_length = length(struc[1]) -batch_target = ChainStorm.pdb2batch(struc) +struc = pdb"7RBY"1[1:1]; +target_length = length(struc[1]); +batch_target = ChainStorm.pdb2batch(struc); -@time rev_g, recorded = flow_quickgen( +@time recorded, rev_g = flow_quickgen( ChainStorm.P, batch_target, ChainStorm.compound_state(batch_target), model; - is_reverse = true, d = gpu, steps = 0f0:0.005f0:1f0, snap_time = 0.9f0); + is_reverse = true, d = gpu, steps = 0f0:0.001f0:1f0, snap_time = 0.9f0); -new_lengths = [122, 114] +#=new_lengths = [122, 114] batch = dummy_batch([ChainStorm.lengths_from_chainids(batch_target.chainids); new_lengths]) batch.resinds .= [batch_target.resinds; 1:new_lengths[1]; 1:new_lengths[2]] -X₀ = ChainStorm.zero_state(batch) +X₀ = ChainStorm.zero_state(batch)=# tracker = ChainStorm.Tracker() @time fwd_g = flow_quickgen( - ChainStorm.P, batch, X₀, model; - tracker, record = deepcopy(recorded), d = gpu, steps = 0f0:0.005f0:1f0); + ChainStorm.P, batch_target, rev_g, model; + tracker, d = gpu, steps = 0f0:0.001f0:1f0); -id = join(ChainStorm.lengths_from_chainids(batch.chainids),'_')*"-"*join(rand('A':'Z', 4)) +id = join(ChainStorm.lengths_from_chainids(batch_target.chainids),'_')*"-"*join(rand('A':'Z', 4)) -export_pdb("$(id)_bind.pdb", fwd_g, batch.chainids, batch.resinds) #<- Save PDB +export_pdb("$(id)_bind.pdb", fwd_g, batch_target.chainids, batch_target.resinds) #<- Save PDB diff --git a/src/flow.jl b/src/flow.jl index 09bd478..25fb432 100644 --- a/src/flow.jl +++ b/src/flow.jl @@ -158,13 +158,65 @@ function flow_quickgen( #b.aas .= convert(Matrix{Int64}, tensor(X0[3]).indices) b.aas .= unhot(X0[3]).S.state if is_reverse + println("Reversing") X0pred = flowX0predictor(X0, b, model, P; d, 
smooth) - return reverse_gen(reverse_process(P), X0, X0pred, Float32.(1 .- reverse(steps)), record; kws...), record - elseif isempty(record) - X1pred = flowX1predictor(X0, b, model; d, smooth) - return gen(P, X0, X1pred, Float32.(steps); kws...) - else + return record, reverse_gen(reverse_process(P), X0, X0pred, Float32.(1 .- reverse(steps)), record; kws...) + elseif !isnothing(record) && !isempty(record) + println("Binding") X1pred = bind_flowX1predictor(X0, b, model, record; d, smooth) return bind_gen(P, X0, X1pred, Float32.(steps), record; kws...) + else + println("Generating") + X1pred = flowX1predictor(X0, b, model; d, smooth) + return gen(P, X0, X1pred, Float32.(steps); kws...) + end +end + +using Flowfusion: lastsize + +function extend(X₀ᵢ, batch, new_lengths) + isnothing(new_lengths) && return X₀ᵢ, batch + new_batch = dummy_batch([lengths_from_chainids(batch.chainids); new_lengths]) + new_batch.resinds .= [batch.resinds; [1:l for l in new_lengths]...] + new_X₀ᵢ = zero_state(new_batch) + s = sum(new_lengths) + tensor(new_X₀ᵢ[1])[:, :, 1:end-s, :] .= tensor(X₀ᵢ[1]) + tensor(new_X₀ᵢ[2])[:, :, 1:end-s, :] .= tensor(X₀ᵢ[2]) + tensor(new_X₀ᵢ[3]).indices[1:end-s, :] .= tensor(X₀ᵢ[3]).indices + return new_X₀ᵢ, new_batch +end + +function flex_quickgen( + P, batch, X₀, model; + step_size = 0.005f0, + phases = [0.0, 1.0], + new_lengths = [nothing, nothing],#fill(nothing, length(phases)-1), + use_record = [false, false],#fill(false, length(phases)-1), + trackers = [Tracker() for _ in 1:length(phases)-1], + kws... +) + X₀ᵢ = X₀ + output = [] + local record + for i in 1:length(phases)-1 + start, stop = phases[i:i+1] + is_reverse = start > stop + a, b = min(start, stop), max(start, stop) + steps = range(Float32(a), Float32(b), step=step_size) + X₀ᵢ, batch = extend(X₀ᵢ, batch, new_lengths[i]) + if is_reverse + record, X₁ᵢ = flow_quickgen( + P, batch, X₀ᵢ, model; + is_reverse, steps, + tracker=trackers[i], kws...) 
+ else + X₁ᵢ = flow_quickgen( + P, batch, X₀ᵢ, model; + record = [], steps, + tracker=trackers[i], kws...) + end + push!(output, (X₁ᵢ, batch, trackers[i])) + X₀ᵢ = X₁ᵢ end + return output end diff --git a/src/reverse.jl b/src/reverse.jl index 033fd60..1809ed4 100644 --- a/src/reverse.jl +++ b/src/reverse.jl @@ -1,10 +1,11 @@ using Flowfusion: resolveprediction, mask, step, - UProcess, UState + tensor, UProcess, UState, unhot function reverse_gen( P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector, record; + #recdim = Flowfusion.lastsize(X₀, offset=1), tracker = Returns(nothing), midpoint = false, snap_time = 0, @@ -23,9 +24,17 @@ function reverse_gen( fakeX̂₁ = deepcopy(X̂₁) tensor(fakeX̂₁[1]) .= tensor(X₀[1]) tensor(fakeX̂₁[2]) .= tensor(X₀[2]) - push!(record, (1-s₂, deepcopy(Xₜ), fakeX̂₁)) #records all the steps + push!(record, ( + 1-s₂, + deepcopy(Xₜ),# 1:recdim, offset=1), + fakeX̂₁#, 1:recdim, offset=1) + )) else - push!(record, (1-s₂, deepcopy(Xₜ), deepcopy(X̂₁))) #records all the steps + push!(record, ( + 1-s₂, + deepcopy(Xₜ),# 1:recdim, offset=1), + deepcopy(X̂₁)#, 1:recdim, offset=1) + )) end tracker(1-t, Xₜ, X̂₁) end @@ -37,26 +46,25 @@ function bind_gen( model, steps::AbstractVector, record; tracker = Returns(nothing), midpoint = false, - snap_time = 0, ) Xₜ = copy.(X₀) for (s₁, s₂) in zip(steps, steps[begin+1:end]) - t = midpoint ? (s₁ + s₂) / 2 : t = s₁ + t = midpoint ? 
(s₁ + s₂) / 2 : s₁ X̂₁ = resolveprediction(model(t, Xₜ), Xₜ) #Changes xt s₁, old_Xₜ, _ = record[end] pop!(record) - tensor(Xₜ[1])[:, :, 1:size(tensor(old_Xₜ[1]), 3), :] .= tensor(old_Xₜ[1]) - tensor(Xₜ[2])[:, :, 1:size(tensor(old_Xₜ[2]), 3), :] .= tensor(old_Xₜ[2]) - tensor(Xₜ[3]).indices[1:size(tensor(old_Xₜ[3]).indices, 1), :] .= tensor(old_Xₜ[3]).indices + old_size = size(tensor(old_Xₜ[1]), 3) + tensor(Xₜ[1])[:, :, 1:old_size, :] .= tensor(old_Xₜ[1]) + tensor(Xₜ[2])[:, :, 1:old_size, :] .= tensor(old_Xₜ[2]) + tensor(Xₜ[3]).indices[1:old_size, :] .= tensor(old_Xₜ[3]).indices Xₜ = mask(step(P, Xₜ, X̂₁, s₁, s₂), X₀) if length(record) == 1 - tensor(Xₜ[1])[:, :, 1:size(tensor(old_Xₜ[1]), 3), :] .= tensor(old_Xₜ[1]) - tensor(Xₜ[2])[:, :, 1:size(tensor(old_Xₜ[2]), 3), :] .= tensor(old_Xₜ[2]) - tensor(Xₜ[3]).indices[1:size(tensor(old_Xₜ[3]).indices, 1), :] .= tensor(old_Xₜ[3]).indices + tensor(Xₜ[1])[:, :, 1:old_size, :] .= tensor(old_Xₜ[1]) + tensor(Xₜ[2])[:, :, 1:old_size, :] .= tensor(old_Xₜ[2]) + tensor(Xₜ[3]).indices[1:old_size, :] .= tensor(old_Xₜ[3]).indices end tracker(t, Xₜ, X̂₁) end return Xₜ end - From 7db4270d5d34ceab5026f70f2b78bae5775e40f9 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Mon, 20 Oct 2025 12:50:12 +0000 Subject: [PATCH 5/6] add simpler reversal interface with flex_quickgen --- .gitignore | 1 + extras/flex-reversal.jl | 44 --------------------------- extras/reversal.jl | 45 +++++++++++++++------------ src/flow.jl | 67 ++++++++++++++++++++++++----------------- src/reverse.jl | 14 +++++---- 5 files changed, 74 insertions(+), 97 deletions(-) delete mode 100644 extras/flex-reversal.jl diff --git a/.gitignore b/.gitignore index 1c02e5e..478db43 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ *.jl.mem Manifest.toml /docs/build/ +movie \ No newline at end of file diff --git a/extras/flex-reversal.jl b/extras/flex-reversal.jl deleted file mode 100644 index c2715a9..0000000 --- a/extras/flex-reversal.jl +++ /dev/null @@ -1,44 +0,0 @@ -using 
Pkg -Pkg.activate("reversal", shared=true) - -using ChainStorm, ChainStorm.Flowfusion, ChainStorm.ProteinChains -using CUDA, Flux, Random - -model = load_model() |> gpu - -struc = pdb"7RBY"1[1:1] -batch_target = ChainStorm.pdb2batch(struc) -X_target = ChainStorm.compound_state(batch_target) - -phase_points_rev = Float32[1.0, 0.0] -@time _, reverse_records, _ = ChainStorm.flex_gen( - ChainStorm.P, batch_target, X_target, model, phase_points_rev; - step_size = -0.005f0, - record_indices = [1], - d = gpu, - snap_time = 0.9f0, -) -recorded = get(reverse_records, 1, nothing) -recorded === nothing && error("Reverse phase did not produce a recording") -recorded = deepcopy(recorded) - -new_lengths = [122, 114] -batch = dummy_batch([ChainStorm.lengths_from_chainids(batch_target.chainids); new_lengths]) -batch.resinds .= [batch_target.resinds; 1:new_lengths[1]; 1:new_lengths[2]] -X₀ = ChainStorm.zero_state(batch) - -tracker = ChainStorm.Tracker() -phase_points_fw = Float32[0.0, 1.0] -initial_record = (_, _) -> deepcopy(recorded) -@time fwd_state, _, _ = ChainStorm.flex_gen( - ChainStorm.P, batch, X₀, model, phase_points_fw; - step_size = 0.005f0, - record_indices = [1], - initial_recorded = initial_record, - tracker = tracker, - d = gpu, -) - -id = join(ChainStorm.lengths_from_chainids(batch.chainids), '_') * "-" * String(rand('A':'Z', 4)) - -export_pdb("$(id)_bind.pdb", fwd_state, batch.chainids, batch.resinds) diff --git a/extras/reversal.jl b/extras/reversal.jl index 04bc252..ac94b7d 100644 --- a/extras/reversal.jl +++ b/extras/reversal.jl @@ -3,29 +3,34 @@ Pkg.activate("reversal", shared=true) #Pkg.develop(path=".") #Pkg.add(["CUDA", "cuDNN", "Flux"]) -using ChainStorm, ChainStorm.Flowfusion, ChainStorm.ProteinChains +using ChainStorm, Flowfusion, ChainStorm.ProteinChains using CUDA, Flux model = load_model() |> gpu; struc = pdb"7RBY"1[1:1]; -target_length = length(struc[1]); batch_target = ChainStorm.pdb2batch(struc); - -@time recorded, rev_g = flow_quickgen( - 
ChainStorm.P, batch_target, ChainStorm.compound_state(batch_target), model; - is_reverse = true, d = gpu, steps = 0f0:0.001f0:1f0, snap_time = 0.9f0); - -#=new_lengths = [122, 114] -batch = dummy_batch([ChainStorm.lengths_from_chainids(batch_target.chainids); new_lengths]) -batch.resinds .= [batch_target.resinds; 1:new_lengths[1]; 1:new_lengths[2]] -X₀ = ChainStorm.zero_state(batch)=# - -tracker = ChainStorm.Tracker() -@time fwd_g = flow_quickgen( - ChainStorm.P, batch_target, rev_g, model; - tracker, d = gpu, steps = 0f0:0.001f0:1f0); - -id = join(ChainStorm.lengths_from_chainids(batch_target.chainids),'_')*"-"*join(rand('A':'Z', 4)) - -export_pdb("$(id)_bind.pdb", fwd_g, batch_target.chainids, batch_target.resinds) #<- Save PDB +X0 = compound_state(batch_target) + +phases = [ + Phase(1.0 => 0.0), + Phase(0.0 => 1.0, use_record=true, new_lengths=[100, 120]), + Phase(1.0 => 0.5, record_dim=195), + Phase(0.5 => 1.0, use_record=true), + Phase(1.0 => 0.75, record_dim=195), + Phase(0.75 => 1.0, use_record=true), +] + +out = flex_quickgen(P, batch_target, X0, model; d=gpu); + +dir = "movie" +isdir(dir) || mkdir(dir) +frame_index = 0; for (i, (phase, (X1, b, tracker))) in enumerate(zip(phases, out)) + for (j, Xₜ) in enumerate(tracker.xt) + frame_index += 1 + export_pdb( + start, stop = phase.interval, + "$dir/$frame_index-$i-$j-$start-$stop.pdb", + Xₜ, b.chainids, b.resinds) + end +end diff --git a/src/flow.jl b/src/flow.jl index 25fb432..622febe 100644 --- a/src/flow.jl +++ b/src/flow.jl @@ -147,6 +147,7 @@ function flow_quickgen( record = [], kws... 
) + b = deepcopy(b) steps = if steps isa Number 0f0:1f0/steps:1f0 elseif steps isa AbstractVector @@ -155,18 +156,17 @@ function flow_quickgen( vcat(zeros(5),S.([0.0:0.00255:0.9975;]),[0.999, 0.9998, 1.0]) end b.locs .= tensor(X0[1]) - #b.aas .= convert(Matrix{Int64}, tensor(X0[3]).indices) b.aas .= unhot(X0[3]).S.state if is_reverse - println("Reversing") + @info "Reversing" X0pred = flowX0predictor(X0, b, model, P; d, smooth) - return record, reverse_gen(reverse_process(P), X0, X0pred, Float32.(1 .- reverse(steps)), record; kws...) + return record, reverse_gen(P, X0, X0pred, Float32.(1 .- reverse(steps)), record; kws...) elseif !isnothing(record) && !isempty(record) - println("Binding") + @info "Binding" X1pred = bind_flowX1predictor(X0, b, model, record; d, smooth) return bind_gen(P, X0, X1pred, Float32.(steps), record; kws...) else - println("Generating") + @info "Generating" X1pred = flowX1predictor(X0, b, model; d, smooth) return gen(P, X0, X1pred, Float32.(steps); kws...) end @@ -174,49 +174,62 @@ end using Flowfusion: lastsize -function extend(X₀ᵢ, batch, new_lengths) - isnothing(new_lengths) && return X₀ᵢ, batch +function add_residues(X₀, batch, new_lengths) + isnothing(new_lengths) && return X₀, batch new_batch = dummy_batch([lengths_from_chainids(batch.chainids); new_lengths]) new_batch.resinds .= [batch.resinds; [1:l for l in new_lengths]...] 
- new_X₀ᵢ = zero_state(new_batch) + new_X₀ = zero_state(new_batch) s = sum(new_lengths) - tensor(new_X₀ᵢ[1])[:, :, 1:end-s, :] .= tensor(X₀ᵢ[1]) - tensor(new_X₀ᵢ[2])[:, :, 1:end-s, :] .= tensor(X₀ᵢ[2]) - tensor(new_X₀ᵢ[3]).indices[1:end-s, :] .= tensor(X₀ᵢ[3]).indices - return new_X₀ᵢ, new_batch + tensor(new_X₀[1])[:, :, 1:end-s, :] .= tensor(X₀[1]) + tensor(new_X₀[2])[:, :, 1:end-s, :] .= tensor(X₀[2]) + tensor(new_X₀[3]).indices[1:end-s, :] .= tensor(X₀[3]).indices + println("Added $(new_lengths) residues") + return new_X₀, new_batch end +@kwdef struct Phase + interval::Pair{Float32,Float32} + step_size::Float32 = 0.005 + new_lengths::Union{Nothing,Vector{Int}} = nothing + record_dim::Union{Nothing,Int} = nothing + use_record::Bool = false + snap_time::Float32 = 0.9 + tracker = Tracker() +end + +Phase(interval::Pair; kws...) = Phase(; interval, kws...) + function flex_quickgen( P, batch, X₀, model; - step_size = 0.005f0, - phases = [0.0, 1.0], - new_lengths = [nothing, nothing],#fill(nothing, length(phases)-1), - use_record = [false, false],#fill(false, length(phases)-1), - trackers = [Tracker() for _ in 1:length(phases)-1], + phases = [Phase(0.0 => 1.0)], kws... ) X₀ᵢ = X₀ output = [] local record - for i in 1:length(phases)-1 - start, stop = phases[i:i+1] + for phase in phases + start, stop = phase.interval is_reverse = start > stop a, b = min(start, stop), max(start, stop) - steps = range(Float32(a), Float32(b), step=step_size) - X₀ᵢ, batch = extend(X₀ᵢ, batch, new_lengths[i]) - if is_reverse + steps = range(Float32(a), Float32(b), step=phase.step_size) + X₀ᵢ, batch = add_residues(X₀ᵢ, batch, phase.new_lengths) + @time if is_reverse record, X₁ᵢ = flow_quickgen( - P, batch, X₀ᵢ, model; + reverse_process(P), batch, X₀ᵢ, model; is_reverse, steps, - tracker=trackers[i], kws...) + recdim = phase.record_dim, + phase.tracker, kws...) else + bind_kws = phase.use_record ? 
(; record, phase.snap_time) : (;) X₁ᵢ = flow_quickgen( P, batch, X₀ᵢ, model; - record = [], steps, - tracker=trackers[i], kws...) + steps, phase.tracker, + bind_kws..., kws...) end - push!(output, (X₁ᵢ, batch, trackers[i])) + push!(output, (X₁ᵢ, batch, phase.tracker)) X₀ᵢ = X₁ᵢ end return output end + +export flex_quickgen, Phase, compound_state \ No newline at end of file diff --git a/src/reverse.jl b/src/reverse.jl index 1809ed4..81b2202 100644 --- a/src/reverse.jl +++ b/src/reverse.jl @@ -1,15 +1,17 @@ using Flowfusion: resolveprediction, mask, step, - tensor, UProcess, UState, unhot + tensor, UProcess, UState, unhot, + selectlastdim, lastsize function reverse_gen( P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector, record; - #recdim = Flowfusion.lastsize(X₀, offset=1), + recdim = nothing, tracker = Returns(nothing), midpoint = false, snap_time = 0, ) + recdim = isnothing(recdim) ? lastsize(X₀, offset=1) : recdim Xₜ = copy.(X₀) push!(record, (1, X₀, nothing)) for (s₁, s₂) in zip(steps, steps[begin+1:end]) @@ -26,14 +28,14 @@ function reverse_gen( tensor(fakeX̂₁[2]) .= tensor(X₀[2]) push!(record, ( 1-s₂, - deepcopy(Xₜ),# 1:recdim, offset=1), - fakeX̂₁#, 1:recdim, offset=1) + selectlastdim(deepcopy(Xₜ), 1:recdim, offset=1), + selectlastdim(fakeX̂₁, 1:recdim, offset=1) )) else push!(record, ( 1-s₂, - deepcopy(Xₜ),# 1:recdim, offset=1), - deepcopy(X̂₁)#, 1:recdim, offset=1) + selectlastdim(deepcopy(Xₜ), 1:recdim, offset=1), + selectlastdim(deepcopy(X̂₁), 1:recdim, offset=1) )) end tracker(1-t, Xₜ, X̂₁) From 1af421e55ffbbf86cc7581ed88ece82284839128 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Mon, 20 Oct 2025 13:19:05 +0000 Subject: [PATCH 6/6] use correct phases --- extras/reversal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extras/reversal.jl b/extras/reversal.jl index ac94b7d..c4aaff6 100644 --- a/extras/reversal.jl +++ b/extras/reversal.jl @@ -21,7 +21,7 @@ phases = [ Phase(0.75 => 1.0, 
use_record=true), ] -out = flex_quickgen(P, batch_target, X0, model; d=gpu); +out = flex_quickgen(P, batch_target, X0, model; phases, d=gpu); dir = "movie" isdir(dir) || mkdir(dir)