diff --git a/.gitignore b/.gitignore index 1c02e5e..478db43 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ *.jl.mem Manifest.toml /docs/build/ +movie \ No newline at end of file diff --git a/README.md b/README.md index b7680c9..246c9fc 100644 --- a/README.md +++ b/README.md @@ -112,3 +112,5 @@ for epoch in 1:100 jldsave("model_epoch_$epoch.jld", model_state = Flux.state(cpu(model)), opt_state=cpu(opt_state)) end ``` + +The code written __will__ break the other examples. diff --git a/src/extras/ChainStormColab.ipynb b/extras/ChainStormColab.ipynb similarity index 100% rename from src/extras/ChainStormColab.ipynb rename to extras/ChainStormColab.ipynb diff --git a/src/extras/circular.jl b/extras/circular.jl similarity index 100% rename from src/extras/circular.jl rename to extras/circular.jl diff --git a/src/extras/overwrite_IPA.jl b/extras/overwrite_IPA.jl similarity index 100% rename from src/extras/overwrite_IPA.jl rename to extras/overwrite_IPA.jl diff --git a/extras/reversal.jl b/extras/reversal.jl new file mode 100644 index 0000000..c4aaff6 --- /dev/null +++ b/extras/reversal.jl @@ -0,0 +1,36 @@ +using Pkg +Pkg.activate("reversal", shared=true) +#Pkg.develop(path=".") +#Pkg.add(["CUDA", "cuDNN", "Flux"]) + +using ChainStorm, Flowfusion, ChainStorm.ProteinChains +using CUDA, Flux + +model = load_model() |> gpu; + +struc = pdb"7RBY"1[1:1]; +batch_target = ChainStorm.pdb2batch(struc); +X0 = compound_state(batch_target) + +phases = [ + Phase(1.0 => 0.0), + Phase(0.0 => 1.0, use_record=true, new_lengths=[100, 120]), + Phase(1.0 => 0.5, record_dim=195), + Phase(0.5 => 1.0, use_record=true), + Phase(1.0 => 0.75, record_dim=195), + Phase(0.75 => 1.0, use_record=true), +] + +out = flex_quickgen(P, batch_target, X0, model; phases, d=gpu); + +dir = "movie" +isdir(dir) || mkdir(dir) +frame_index = 0; for (i, (phase, (X1, b, tracker))) in enumerate(zip(phases, out)) + for (j, Xₜ) in enumerate(tracker.xt) + frame_index += 1 + export_pdb( + start, stop = phase.interval, + "$dir/$frame_index-$i-$j-$start-$stop.pdb", + Xₜ, b.chainids, b.resinds) + end +end diff --git a/src/ChainStorm.jl b/src/ChainStorm.jl index 76c8e6f..d679408 100644 --- a/src/ChainStorm.jl +++ b/src/ChainStorm.jl @@ -1,7 +1,18 @@ module ChainStorm -using Flowfusion, ForwardBackward, Flux, RandomFeatureMaps, Onion, InvariantPointAttention, BatchedTransformations, ProteinChains, DLProteinFormats, HuggingFaceApi, JLD2 +using Flowfusion +using ForwardBackward +using Flux +using RandomFeatureMaps +using Onion +using InvariantPointAttention +using BatchedTransformations +using ProteinChains +using DLProteinFormats +using HuggingFaceApi +using JLD2 +include("reverse.jl") include("flow.jl") include("model.jl") @@ -10,6 +21,28 @@ function load_model(; checkpoint = "ChainStormV1.jld2") return Flux.loadmodel!(ChainStormV1(), JLD2.load(file, "model_state")) end +function pdb2batch(struc::ProteinChains.ProteinStructure) + struc.cluster = 1 + return DLProteinFormats.batch_flatrecs([DLProteinFormats.flatten(struc),]) +end + +function lengths_from_chainids(chainids) + counts = Int[] # Initialize an empty array to store counts + current_count = 1 + + for i in 2:length(chainids) + if chainids[i] == chainids[i - 1] + current_count += 1 + else + push!(counts, current_count) + current_count = 1 + end + end + + push!(counts, current_count) + return counts +end + chainids_from_lengths(lengths) = vcat([repeat([i],l) for (i,l) in enumerate(lengths)]...) function gen2prot(samp, chainids, resnums; name = "Gen", ) d = Dict(zip(0:25,'A':'Z')) diff --git a/src/flow.jl b/src/flow.jl index 4c4a315..622febe 100644 --- a/src/flow.jl +++ b/src/flow.jl @@ -3,6 +3,32 @@ const rotM = Flowfusion.Rotations(3) schedule_f(t) = 1-(1-t)^2 const P = (FProcess(BrownianMotion(0.2f0), schedule_f), FProcess(ManifoldProcess(0.2f0), schedule_f), NoisyInterpolatingDiscreteFlow(0.2f0, K = 2, dummy_token = 21)) +#Bringing Alexander's version in - this should be replaced by the "full" solution: +function rev_NoisyInterpolatingDiscreteFlow(noise; K = 1, dummy_token::T = nothing) where T + if (K > 1 && isnothing(dummy_token)) + @warn "NoisyInterpolatingDiscreteFlow: If K>1 things might break if your X0 is not the `dummy_token` (which should also be passed to NoisyInterpolatingDiscreteFlow)." + end + return NoisyInterpolatingDiscreteFlow{T}( + t -> oftype(t,1-(1-cos((π/2)*(1-t)))^K), #K1 + t -> oftype(t,(noise * sin(π*t))), #K2 + t -> oftype(t,(K * (π/2) * cos((π/2) * t) * (1 - sin((π/2) * t))^(K - 1))), #dK1 + t -> oftype(t,(noise*π*cos(π*t))), #dK2 + dummy_token + ) +end + +function reverse_process(P) + continuous_schedule = t -> 1 - P[1].F(1 - t) + manifold_schedule = t -> 1 - P[2].F(1 - t) + κ₁ = t -> 1 - P[3].κ₁(1 - t) + dκ₁ = t -> P[3].dκ₁(1 - t) + κ₂ = t -> 1 - P[3].κ₂(1 - t) + dκ₂ = t -> P[3].dκ₁(1 - t) + #This is a hack fix: + #(Flowfusion.FProcess(P[1].P, continuous_schedule), Flowfusion.FProcess(P[2].P, manifold_schedule), Flowfusion.NoisyInterpolatingDiscreteFlow(κ₁, dκ₁, κ₂, dκ₂, P[3].mask_token)) + (Flowfusion.FProcess(P[1].P, continuous_schedule), Flowfusion.FProcess(P[2].P, manifold_schedule), rev_NoisyInterpolatingDiscreteFlow(0.2f0, K = 2, dummy_token = 21)) +end + function compound_state(b) L,B = size(b.aas) cmask = b.aas .< 100 @@ -16,7 +42,7 @@ function zero_state(b) L,B = size(b.aas) cmask = b.aas .< 100 X0locs = MaskedState(ContinuousState(randn(Float32, size(b.locs))), cmask, b.padmask) - X0rots = MaskedState(ManifoldState(rotM, reshape(Array{Float32}.(Flowfusion.rand(rotM, L*B)), L, B)), cmask, b.padmask) + X0rots = MaskedState(ManifoldState(rotM, reshape(Array{Float32}.(rand(rotM, L*B)), L, B)), cmask, b.padmask) X0aas = MaskedState(DiscreteState(21, Flux.onehotbatch(similar(b.aas) .= 21, 1:21)), cmask, b.padmask) return (X0locs, X0rots, X0aas) end @@ -45,7 +71,6 @@ function flowX1predictor(X0, b, model; d = identity, smooth = 0) prev_trans = values(translation(f)) T = eltype(prev_trans) function m(t, Xt) - print(".") f, aalogits = model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames = f) values(translation(f)) .= prev_trans .* T(smooth) .+ values(translation(f)) .* T(1-smooth) prev_trans = values(translation(f)) @@ -54,17 +79,157 @@ function flowX1predictor(X0, b, model; d = identity, smooth = 0) return m end +function flowX0predictor(X0, b, model, P; d = identity, smooth = 0) # Forces P to be a FProcess and doesn't work for some reason for P Deterministic + batch_dim = size(tensor(X0[1]), 4) + ff, _ = model(d(ones(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds)) # ones makes it start at time = 1 + if P[1].P isa Deterministic + v = 0 + else + v = P[1].P.v + end + function m(rt, Xt) + ff, aalogits = model(d(1-rt .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames=ff) + aalogits = deepcopy(cpu(aalogits)) + X1Hat = deepcopy(cpu(ff)) + t = 1f0 .- P[1].F.(rt .+ zeros(Float32, 1, batch_dim)) + t[t .>= 0.999] .= 0.999 + values(translation(X1Hat)) .= (tensor(Xt[1]) .- values(translation(X1Hat)) .* t) ./ (1 .- t + v .* t) + M = Xt[2].S.M + p = eachslice(tensor(Xt[2]), dims=(3, 4)) + tangent = -t ./ (1 .- t) .* log.((M,), p, eachslice(values(linear(X1Hat)), dims=(3, 4))) + X0Hat = exp.((M,), p, tangent) + values(linear(X1Hat)) .= stack(X0Hat) + T = eltype(aalogits) + aalogits .= T(-Inf) + aalogits[21,:,:] .= 0 + return (cpu(values(translation(X1Hat))), ManifoldState(rotM, eachslice(cpu(values(linear(X1Hat))), dims=(3,4))), cpu(softmax(aalogits))), (cpu(values(translation(ff))), ManifoldState(rotM, eachslice(cpu(values(linear(ff))), dims=(3,4))), cpu(softmax(aalogits))) + end + return m +end + +function bind_flowX1predictor(X0, b, model, recorded; d = identity, smooth = 0, meanshift = true) + recdim = size(tensor(recorded[end][3][1]), 3) + batch_dim = size(tensor(X0[1]), 4) + f, _ = cpu(model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds))) + values(translation(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][1]) # Might be more sensible to do a weighted average of X̂₁ and (1-t)*(Xₜ₊Δₜ - Xₜ)/Δt + values(linear(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][2]) + f, _ = cpu(model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds), sc_frames=d(f))) + recmean = Flux.mean(values(translation(f))[:, :, 1:recdim, :], dims = 3) + forcemean = Flux.mean(tensor(recorded[end][3][1]), dims = 3) + values(translation(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][1]) + if meanshift + values(translation(f))[:, :, 1:recdim, :] .+= forcemean .- recmean + end + values(linear(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][2]) + function m(t, Xt) + f, aalogits = cpu(model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames = d(f))) + recmean = Flux.mean(values(translation(f))[:, :, 1:recdim, :], dims = 3) + forcemean = Flux.mean(tensor(recorded[end][3][1]), dims = 3) + values(translation(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][1]) + if meanshift #This is to shift the binder over by the amount the target would have shifted, in the other direction: + values(translation(f))[:, :, recdim+1:end, :] .+= forcemean .- recmean + end + values(linear(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][2]) + return cpu(values(translation(f))), ManifoldState(rotM, eachslice(cpu(values(linear(f))), dims=(3,4))), cpu(softmax(aalogits)) + end + return m +end + H(a; d = 2/3) = a<=d ? (a^2)/2 : d*(a - d/2) S(a) = H(a)/H(1) -function flow_quickgen(b, model; steps = :default, d = identity, tracker = Returns(nothing), smooth = 0.6) - stps = vcat(zeros(5),S.([0.0:0.00255:0.9975;]),[0.999, 0.9998, 1.0]) - if steps isa Number - stps = 0f0:1f0/steps:1f0 +function flow_quickgen( + P, b, X0, model; + is_reverse = false, + steps = :default, + d = identity, + smooth = 0, + record = [], + kws... +) + b = deepcopy(b) + steps = if steps isa Number + 0f0:1f0/steps:1f0 elseif steps isa AbstractVector - stps = steps + steps + else + vcat(zeros(5),S.([0.0:0.00255:0.9975;]),[0.999, 0.9998, 1.0]) end - X0 = zero_state(b) - X1pred = flowX1predictor(X0, b, model, d = d, smooth = smooth) - return gen(P, X0, X1pred, Float32.(stps), tracker = tracker) + b.locs .= tensor(X0[1]) + b.aas .= unhot(X0[3]).S.state + if is_reverse + @info "Reversing" + X0pred = flowX0predictor(X0, b, model, P; d, smooth) + return record, reverse_gen(P, X0, X0pred, Float32.(1 .- reverse(steps)), record; kws...) + elseif !isnothing(record) && !isempty(record) + @info "Binding" + X1pred = bind_flowX1predictor(X0, b, model, record; d, smooth) + return bind_gen(P, X0, X1pred, Float32.(steps), record; kws...) + else + @info "Generating" + X1pred = flowX1predictor(X0, b, model; d, smooth) + return gen(P, X0, X1pred, Float32.(steps); kws...) + end +end + +using Flowfusion: lastsize + +function add_residues(X₀, batch, new_lengths) + isnothing(new_lengths) && return X₀, batch + new_batch = dummy_batch([lengths_from_chainids(batch.chainids); new_lengths]) + new_batch.resinds .= [batch.resinds; [1:l for l in new_lengths]...] + new_X₀ = zero_state(new_batch) + s = sum(new_lengths) + tensor(new_X₀[1])[:, :, 1:end-s, :] .= tensor(X₀[1]) + tensor(new_X₀[2])[:, :, 1:end-s, :] .= tensor(X₀[2]) + tensor(new_X₀[3]).indices[1:end-s, :] .= tensor(X₀[3]).indices + println("Added $(new_lengths) residues") + return new_X₀, new_batch end + +@kwdef struct Phase + interval::Pair{Float32,Float32} + step_size::Float32 = 0.005 + new_lengths::Union{Nothing,Vector{Int}} = nothing + record_dim::Union{Nothing,Int} = nothing + use_record::Bool = false + snap_time::Float32 = 0.9 + tracker = Tracker() +end + +Phase(interval::Pair; kws...) = Phase(; interval, kws...) + +function flex_quickgen( + P, batch, X₀, model; + phases = [Phase(0.0 => 1.0)], + kws... +) + X₀ᵢ = X₀ + output = [] + local record + for phase in phases + start, stop = phase.interval + is_reverse = start > stop + a, b = min(start, stop), max(start, stop) + steps = range(Float32(a), Float32(b), step=phase.step_size) + X₀ᵢ, batch = add_residues(X₀ᵢ, batch, phase.new_lengths) + @time if is_reverse + record, X₁ᵢ = flow_quickgen( + reverse_process(P), batch, X₀ᵢ, model; + is_reverse, steps, + recdim = phase.record_dim, + phase.tracker, kws...) + else + bind_kws = phase.use_record ? (; record, phase.snap_time) : (;) + X₁ᵢ = flow_quickgen( + P, batch, X₀ᵢ, model; + steps, phase.tracker, + bind_kws..., kws...) + end + push!(output, (X₁ᵢ, batch, phase.tracker)) + X₀ᵢ = X₁ᵢ + end + return output +end + +export flex_quickgen, Phase, compound_state \ No newline at end of file diff --git a/src/reverse.jl b/src/reverse.jl new file mode 100644 index 0000000..81b2202 --- /dev/null +++ b/src/reverse.jl @@ -0,0 +1,72 @@ +using Flowfusion: + resolveprediction, mask, step, + tensor, UProcess, UState, unhot, + selectlastdim, lastsize + +function reverse_gen( + P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, + model, steps::AbstractVector, record; + recdim = nothing, + tracker = Returns(nothing), + midpoint = false, + snap_time = 0, +) + recdim = isnothing(recdim) ? lastsize(X₀, offset=1) : recdim + Xₜ = copy.(X₀) + push!(record, (1, X₀, nothing)) + for (s₁, s₂) in zip(steps, steps[begin+1:end]) + T = eltype(s₁) + s₁ = s₁ == s₂ ? s₁ - T(0.001) : s₁ + t = midpoint ? (s₁ + s₂) / 2 : s₁ + X̂₀, X̂₁ = model(t, Xₜ) + X̂₀ = resolveprediction(X̂₀, Xₜ) + X̂₁ = resolveprediction(X̂₁, Xₜ) + Xₜ = mask(step(P, Xₜ, X̂₀, s₁, s₂), X₀) + if t < snap_time + fakeX̂₁ = deepcopy(X̂₁) + tensor(fakeX̂₁[1]) .= tensor(X₀[1]) + tensor(fakeX̂₁[2]) .= tensor(X₀[2]) + push!(record, ( + 1-s₂, + selectlastdim(deepcopy(Xₜ), 1:recdim, offset=1), + selectlastdim(fakeX̂₁, 1:recdim, offset=1) + )) + else + push!(record, ( + 1-s₂, + selectlastdim(deepcopy(Xₜ), 1:recdim, offset=1), + selectlastdim(deepcopy(X̂₁), 1:recdim, offset=1) + )) + end + tracker(1-t, Xₜ, X̂₁) + end + return Xₜ +end + +function bind_gen( + P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, + model, steps::AbstractVector, record; + tracker = Returns(nothing), + midpoint = false, +) + Xₜ = copy.(X₀) + for (s₁, s₂) in zip(steps, steps[begin+1:end]) + t = midpoint ? (s₁ + s₂) / 2 : s₁ + X̂₁ = resolveprediction(model(t, Xₜ), Xₜ) + #Changes xt + s₁, old_Xₜ, _ = record[end] + pop!(record) + old_size = size(tensor(old_Xₜ[1]), 3) + tensor(Xₜ[1])[:, :, 1:old_size, :] .= tensor(old_Xₜ[1]) + tensor(Xₜ[2])[:, :, 1:old_size, :] .= tensor(old_Xₜ[2]) + tensor(Xₜ[3]).indices[1:old_size, :] .= tensor(old_Xₜ[3]).indices + Xₜ = mask(step(P, Xₜ, X̂₁, s₁, s₂), X₀) + if length(record) == 1 + tensor(Xₜ[1])[:, :, 1:old_size, :] .= tensor(old_Xₜ[1]) + tensor(Xₜ[2])[:, :, 1:old_size, :] .= tensor(old_Xₜ[2]) + tensor(Xₜ[3]).indices[1:old_size, :] .= tensor(old_Xₜ[3]).indices + end + tracker(t, Xₜ, X̂₁) + end + return Xₜ +end