Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
*.jl.mem
Manifest.toml
/docs/build/
movie
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,5 @@ for epoch in 1:100
jldsave("model_epoch_$epoch.jld", model_state = Flux.state(cpu(model)), opt_state=cpu(opt_state))
end
```

Note: the changes on this branch __will__ break the other examples.
File renamed without changes.
File renamed without changes.
File renamed without changes.
36 changes: 36 additions & 0 deletions extras/reversal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Example script: noise a target structure, then regenerate it together with new
# binder chains, and dump every tracked intermediate state as a PDB "movie" frame.
using Pkg
Pkg.activate("reversal", shared=true)
#Pkg.develop(path=".")
#Pkg.add(["CUDA", "cuDNN", "Flux"])

using ChainStorm, Flowfusion, ChainStorm.ProteinChains
using CUDA, Flux

model = load_model() |> gpu;

struc = pdb"7RBY"1[1:1];
batch_target = ChainStorm.pdb2batch(struc);
X0 = compound_state(batch_target)

# Each phase runs the flow over `start => stop`; start > stop is a reverse (noising)
# pass, and `use_record=true` replays the record produced by the preceding reverse pass.
phases = [
    Phase(1.0 => 0.0),
    Phase(0.0 => 1.0, use_record=true, new_lengths=[100, 120]),
    Phase(1.0 => 0.5, record_dim=195),
    Phase(0.5 => 1.0, use_record=true),
    Phase(1.0 => 0.75, record_dim=195),
    Phase(0.75 => 1.0, use_record=true),
]

# `P` is qualified so the script works whether or not ChainStorm exports it.
out = flex_quickgen(ChainStorm.P, batch_target, X0, model; phases, d=gpu);

dir = "movie"
isdir(dir) || mkdir(dir)

# Write one PDB file per tracked state and return the number of frames written.
# The counter lives inside a function so that `frame_index += 1` updates a local:
# in a top-level `for` loop of a script file it would be an ambiguous soft-scope
# assignment and fail with an UndefVarError.
function export_frames(dir, phases, out)
    frame_index = 0
    for (i, (phase, (X1, b, tracker))) in enumerate(zip(phases, out))
        # Must be destructured before the call; it is part of the file name.
        start, stop = phase.interval
        for (j, Xₜ) in enumerate(tracker.xt)
            frame_index += 1
            export_pdb(
                "$dir/$frame_index-$i-$j-$start-$stop.pdb",
                Xₜ, b.chainids, b.resinds)
        end
    end
    return frame_index
end

export_frames(dir, phases, out)
35 changes: 34 additions & 1 deletion src/ChainStorm.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
module ChainStorm

using Flowfusion, ForwardBackward, Flux, RandomFeatureMaps, Onion, InvariantPointAttention, BatchedTransformations, ProteinChains, DLProteinFormats, HuggingFaceApi, JLD2
using Flowfusion
using ForwardBackward
using Flux
using RandomFeatureMaps
using Onion
using InvariantPointAttention
using BatchedTransformations
using ProteinChains
using DLProteinFormats
using HuggingFaceApi
using JLD2

include("reverse.jl")
include("flow.jl")
include("model.jl")

Expand All @@ -10,6 +21,28 @@ function load_model(; checkpoint = "ChainStormV1.jld2")
return Flux.loadmodel!(ChainStormV1(), JLD2.load(file, "model_state"))
end

"""
    pdb2batch(struc::ProteinChains.ProteinStructure)

Turn a single structure into a one-element batch by flattening it with
`DLProteinFormats.flatten` and batching the result with
`DLProteinFormats.batch_flatrecs`.

Note that `struc.cluster` is set to `1` on the structure that is passed in.
"""
function pdb2batch(struc::ProteinChains.ProteinStructure)
    struc.cluster = 1
    flat = DLProteinFormats.flatten(struc)
    return DLProteinFormats.batch_flatrecs([flat])
end

"""
    lengths_from_chainids(chainids)

Return the lengths of the consecutive runs of equal values in `chainids`, in the
order the runs appear (a run-length encoding without the values).

`chainids` is traversed in iteration order, so a single-column matrix behaves like
a vector. An empty input yields an empty `Vector{Int}`.
"""
function lengths_from_chainids(chainids)
    counts = Int[]
    # Without this guard an empty input would report one chain of length 1.
    isempty(chainids) && return counts
    current_count = 1
    # Pair every element with its successor instead of indexing from 2, so the
    # function does not depend on 1-based indices.
    for (previous, current) in zip(chainids, Iterators.drop(chainids, 1))
        if current == previous
            current_count += 1
        else
            push!(counts, current_count)
            current_count = 1
        end
    end
    push!(counts, current_count)
    return counts
end

"""
    chainids_from_lengths(lengths)

Return a `Vector{Int}` in which the index `i` of every entry of `lengths` is
repeated `lengths[i]` times, e.g. `[2, 3]` gives `[1, 1, 2, 2, 2]`.

An empty `lengths` yields an empty `Vector{Int}`.
"""
function chainids_from_lengths(lengths)
    ids = Int[]
    # Appending avoids splatting a collection of unknown length into `vcat`.
    for (i, l) in enumerate(lengths)
        append!(ids, fill(i, l))
    end
    return ids
end
function gen2prot(samp, chainids, resnums; name = "Gen", )
d = Dict(zip(0:25,'A':'Z'))
Expand Down
185 changes: 175 additions & 10 deletions src/flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,32 @@ const rotM = Flowfusion.Rotations(3)
# Time schedule shared by the two continuous processes below: a quadratic curve
# with schedule_f(0) == 0 and schedule_f(1) == 1.
schedule_f(t) = 1-(1-t)^2
# Default process tuple. Entries line up with the state tuple built by `zero_state`
# (locations, rotations, amino acids): a scheduled BrownianMotion, a scheduled
# ManifoldProcess, and a NoisyInterpolatingDiscreteFlow whose dummy token is 21.
const P = (FProcess(BrownianMotion(0.2f0), schedule_f), FProcess(ManifoldProcess(0.2f0), schedule_f), NoisyInterpolatingDiscreteFlow(0.2f0, K = 2, dummy_token = 21))

#Bringing Alexander's version in - this should be replaced by the "full" solution:
"""
    rev_NoisyInterpolatingDiscreteFlow(noise; K = 1, dummy_token = nothing)

Construct a `NoisyInterpolatingDiscreteFlow` from four closed-form schedule
functions of `t`, passed to the constructor in the order κ₁, κ₂, dκ₁, dκ₂ and
followed by `dummy_token`. Every schedule converts its result back to the type of
`t` via `oftype`.

A warning is logged when `K > 1` and no `dummy_token` is given.
"""
function rev_NoisyInterpolatingDiscreteFlow(noise; K = 1, dummy_token::T = nothing) where T
    needs_dummy = K > 1 && isnothing(dummy_token)
    if needs_dummy
        @warn "NoisyInterpolatingDiscreteFlow: If K>1 things might break if your X0 is not the `dummy_token` (which should also be passed to NoisyInterpolatingDiscreteFlow)."
    end
    # Interpolation schedule and the noise bump, scaled by `noise`.
    κ₁(t) = oftype(t,1-(1-cos((π/2)*(1-t)))^K)
    κ₂(t) = oftype(t,(noise * sin(π*t)))
    # The matching time derivatives.
    dκ₁(t) = oftype(t,(K * (π/2) * cos((π/2) * t) * (1 - sin((π/2) * t))^(K - 1)))
    dκ₂(t) = oftype(t,(noise*π*cos(π*t)))
    return NoisyInterpolatingDiscreteFlow{T}(κ₁, κ₂, dκ₁, dκ₂, dummy_token)
end

"""
    reverse_process(P)

Return the process tuple used for reverse-time runs of the three-component
process tuple `P`.

The two continuous components keep their underlying process `P[i].P` but use the
time-reversed schedule `t -> 1 - P[i].F(1 - t)`.

The discrete component is currently a hack: it ignores `P[3]` and is always
`rev_NoisyInterpolatingDiscreteFlow(0.2f0, K = 2, dummy_token = 21)`, which
mirrors the constants of the module-level default `P`.
"""
function reverse_process(P)
    continuous_schedule = t -> 1 - P[1].F(1 - t)
    manifold_schedule = t -> 1 - P[2].F(1 - t)
    # TODO: derive the discrete flow from P[3] instead of hard-coding it. An earlier
    # draft built reversed κ₁/dκ₁/κ₂/dκ₂ closures here, but they were never used
    # (and its dκ₂ was a copy of dκ₁), so they have been dropped until the general
    # solution is implemented.
    return (
        Flowfusion.FProcess(P[1].P, continuous_schedule),
        Flowfusion.FProcess(P[2].P, manifold_schedule),
        rev_NoisyInterpolatingDiscreteFlow(0.2f0, K = 2, dummy_token = 21),
    )
end

function compound_state(b)
L,B = size(b.aas)
cmask = b.aas .< 100
Expand All @@ -16,7 +42,7 @@ function zero_state(b)
L,B = size(b.aas)
cmask = b.aas .< 100
X0locs = MaskedState(ContinuousState(randn(Float32, size(b.locs))), cmask, b.padmask)
X0rots = MaskedState(ManifoldState(rotM, reshape(Array{Float32}.(Flowfusion.rand(rotM, L*B)), L, B)), cmask, b.padmask)
X0rots = MaskedState(ManifoldState(rotM, reshape(Array{Float32}.(rand(rotM, L*B)), L, B)), cmask, b.padmask)
X0aas = MaskedState(DiscreteState(21, Flux.onehotbatch(similar(b.aas) .= 21, 1:21)), cmask, b.padmask)
return (X0locs, X0rots, X0aas)
end
Expand Down Expand Up @@ -45,7 +71,6 @@ function flowX1predictor(X0, b, model; d = identity, smooth = 0)
prev_trans = values(translation(f))
T = eltype(prev_trans)
function m(t, Xt)
print(".")
f, aalogits = model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames = f)
values(translation(f)) .= prev_trans .* T(smooth) .+ values(translation(f)) .* T(1-smooth)
prev_trans = values(translation(f))
Expand All @@ -54,17 +79,157 @@ function flowX1predictor(X0, b, model; d = identity, smooth = 0)
return m
end

# Build the predictor closure for reverse-time sampling. `flow_quickgen` passes the
# returned closure to `reverse_gen`, which calls it as `model(t, Xₜ)` and unpacks the
# result as `(X̂₀, X̂₁)`.
#
# Arguments:
#   X0, b  - state tuple and batch (only `b.chainids` and `b.resinds` are read here)
#   model  - network called as model(t, X, chainids, resinds; sc_frames)
#   P      - process tuple; `P[1].F` and `P[1].P` are read
#   d      - device mover applied to every network input (e.g. `gpu`)
#   smooth - accepted but not used in this function
function flowX0predictor(X0, b, model, P; d = identity, smooth = 0) # Forces P to be a FProcess and doesn't work for some reason for P Deterministic
batch_dim = size(tensor(X0[1]), 4)
# Initial frames used as self-conditioning input for the first closure call.
ff, _ = model(d(ones(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds)) # ones makes it start at time = 1
# `v` enters the translation denominator below; presumably the noise scale of the
# location process - TODO confirm.
if P[1].P isa Deterministic
v = 0
else
v = P[1].P.v
end
# rt is the reverse-time coordinate; the network itself is evaluated at 1 - rt.
function m(rt, Xt)
# `ff` is reassigned here, so each call self-conditions on the previous call's frames.
ff, aalogits = model(d(1-rt .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames=ff)
aalogits = deepcopy(cpu(aalogits))
# X1Hat starts as a CPU copy of the predicted frames and is overwritten in place
# below with the X0 estimate.
X1Hat = deepcopy(cpu(ff))
t = 1f0 .- P[1].F.(rt .+ zeros(Float32, 1, batch_dim))
# Cap t so the divisions by (1 - t) below stay finite.
t[t .>= 0.999] .= 0.999
# Translation estimate: solve for the start point from Xt and the predicted end point.
values(translation(X1Hat)) .= (tensor(Xt[1]) .- values(translation(X1Hat)) .* t) ./ (1 .- t + v .* t)
M = Xt[2].S.M
p = eachslice(tensor(Xt[2]), dims=(3, 4))
# Rotation estimate: scale the tangent pointing from p to the predicted rotation by
# -t/(1-t) and map it back onto the manifold.
tangent = -t ./ (1 .- t) .* log.((M,), p, eachslice(values(linear(X1Hat)), dims=(3, 4)))
X0Hat = exp.((M,), p, tangent)
values(linear(X1Hat)) .= stack(X0Hat)
T = eltype(aalogits)
# The network's logits are discarded: every entry is -Inf except index 21 along the
# first dimension, so the softmax puts all mass on 21 (the dummy token of `P`).
aalogits .= T(-Inf)
aalogits[21,:,:] .= 0
# First tuple: X0 estimate. Second tuple: the network's own prediction (`ff`).
# Both carry the same forced amino-acid distribution.
return (cpu(values(translation(X1Hat))), ManifoldState(rotM, eachslice(cpu(values(linear(X1Hat))), dims=(3,4))), cpu(softmax(aalogits))), (cpu(values(translation(ff))), ManifoldState(rotM, eachslice(cpu(values(linear(ff))), dims=(3,4))), cpu(softmax(aalogits)))
end
return m
end

# Build the X1 predictor used on the "Binding" path of `flow_quickgen`: the first
# `recdim` positions of every prediction are overwritten with the recorded
# prediction stored in the last entry of `recorded`.
#
# Arguments:
#   X0, b     - state tuple and batch (only `b.chainids` and `b.resinds` are read here)
#   model     - network called as model(t, X, chainids, resinds; sc_frames)
#   recorded  - vector of (time, Xₜ, X̂₁) tuples; only `recorded[end][3]` is read
#   d         - device mover applied to every network input
#   smooth    - accepted but not used in this function
#   meanshift - whether to apply the mean correction described below
function bind_flowX1predictor(X0, b, model, recorded; d = identity, smooth = 0, meanshift = true)
# Number of positions (dimension 3) covered by the recorded prediction.
recdim = size(tensor(recorded[end][3][1]), 3)
batch_dim = size(tensor(X0[1]), 4)
# Warm-up pass 1: predict at t = 0, then force the recorded part into the frames.
f, _ = cpu(model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds)))
values(translation(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][1]) # Might be more sensible to do a weighted average of X̂₁ and (1-t)*(Xₜ₊Δₜ - Xₜ)/Δt
values(linear(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][2])
# Warm-up pass 2: predict again, self-conditioned on the forced frames.
f, _ = cpu(model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds), sc_frames=d(f)))
# Means over the recorded positions: what the network predicted vs. what is imposed.
recmean = Flux.mean(values(translation(f))[:, :, 1:recdim, :], dims = 3)
forcemean = Flux.mean(tensor(recorded[end][3][1]), dims = 3)
values(translation(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][1])
# NOTE(review): this warm-up shift is applied to the recorded positions 1:recdim,
# whereas the closure below shifts the remaining positions recdim+1:end - confirm
# that the difference is intended.
if meanshift
values(translation(f))[:, :, 1:recdim, :] .+= forcemean .- recmean
end
values(linear(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][2])
# Per-step predictor. `f` is reassigned, so each call self-conditions on the
# previous call's (already overwritten) frames. `recorded[end]` is re-read on every
# call, so it follows the record as `bind_gen` pops entries from it.
function m(t, Xt)
f, aalogits = cpu(model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames = d(f)))
recmean = Flux.mean(values(translation(f))[:, :, 1:recdim, :], dims = 3)
forcemean = Flux.mean(tensor(recorded[end][3][1]), dims = 3)
values(translation(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][1])
if meanshift #This is to shift the binder over by the amount the target would have shifted, in the other direction:
values(translation(f))[:, :, recdim+1:end, :] .+= forcemean .- recmean
end
values(linear(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][2])
# Returned as (translations, rotations, amino-acid probabilities).
return cpu(values(translation(f))), ManifoldState(rotM, eachslice(cpu(values(linear(f))), dims=(3,4))), cpu(softmax(aalogits))
end
return m
end

"""
    H(a; d = 2/3)

Piecewise curve: quadratic `a^2/2` for `a <= d`, and the linear continuation
`d*(a - d/2)` beyond `d`.
"""
function H(a; d = 2/3)
    if a <= d
        return (a^2)/2
    end
    return d*(a - d/2)
end

"""
    S(a)

`H(a)` rescaled by `H(1)`, so that `S(1) == 1`.
"""
function S(a)
    return H(a)/H(1)
end

"""
    flow_quickgen(b, model; steps = :default, d = identity, tracker = Returns(nothing), smooth = 0.6)

Legacy entry point: generate from a fresh `zero_state(b)` using the module-level
process tuple `P`.

`steps` may be a number (that many uniform steps on 0..1), a vector of time
points, or anything else to select the default non-uniform schedule.
"""
function flow_quickgen(b, model; steps = :default, d = identity, tracker = Returns(nothing), smooth = 0.6)
    stps = vcat(zeros(5),S.([0.0:0.00255:0.9975;]),[0.999, 0.9998, 1.0])
    if steps isa Number
        stps = 0f0:1f0/steps:1f0
    elseif steps isa AbstractVector
        stps = steps
    end
    X0 = zero_state(b)
    X1pred = flowX1predictor(X0, b, model, d = d, smooth = smooth)
    return gen(P, X0, X1pred, Float32.(stps), tracker = tracker)
end

"""
    flow_quickgen(P, b, X0, model; is_reverse = false, steps = :default,
                  d = identity, smooth = 0, record = [], kws...)

Run one pass of the flow from the state tuple `X0` under the process tuple `P`.

The batch `b` is deep-copied, and the copy's `locs` and `aas` are overwritten from
`X0` before any predictor is built. The pass is then one of:

- `is_reverse = true`: a reverse pass through `reverse_gen`; returns the tuple
  `(record, state)`.
- non-empty `record`: a binding pass through `bind_gen` that replays `record`;
  returns the final state.
- otherwise: a plain generation pass through `gen`; returns the final state.

`steps` may be a number (that many uniform steps on 0..1), a vector of time
points, or anything else to select the default non-uniform schedule. Remaining
keywords `kws` are forwarded to the selected sampler.
"""
function flow_quickgen(
    P, b, X0, model;
    is_reverse = false,
    steps = :default,
    d = identity,
    smooth = 0,
    record = [],
    kws...
)
    b = deepcopy(b)
    steps = if steps isa Number
        0f0:1f0/steps:1f0
    elseif steps isa AbstractVector
        steps
    else
        vcat(zeros(5),S.([0.0:0.00255:0.9975;]),[0.999, 0.9998, 1.0])
    end
    b.locs .= tensor(X0[1])
    b.aas .= unhot(X0[3]).S.state
    if is_reverse
        @info "Reversing"
        X0pred = flowX0predictor(X0, b, model, P; d, smooth)
        # The reverse sampler walks the flipped grid 1 .- reverse(steps).
        return record, reverse_gen(P, X0, X0pred, Float32.(1 .- reverse(steps)), record; kws...)
    elseif !isnothing(record) && !isempty(record)
        @info "Binding"
        X1pred = bind_flowX1predictor(X0, b, model, record; d, smooth)
        return bind_gen(P, X0, X1pred, Float32.(steps), record; kws...)
    else
        @info "Generating"
        X1pred = flowX1predictor(X0, b, model; d, smooth)
        return gen(P, X0, X1pred, Float32.(steps); kws...)
    end
end

using Flowfusion: lastsize

"""
    add_residues(X₀, batch, new_lengths)

Extend the state tuple `X₀` and `batch` by extra chains of the given lengths and
return `(new_X₀, new_batch)`.

With `new_lengths === nothing` the inputs are returned unchanged. Otherwise a new
batch is created with `dummy_batch` from the existing chain lengths followed by
`new_lengths`, a fresh `zero_state` is drawn for it, and the existing state is
copied over everything except the trailing `sum(new_lengths)` positions.
"""
function add_residues(X₀, batch, new_lengths)
    if isnothing(new_lengths)
        return X₀, batch
    end
    all_lengths = [lengths_from_chainids(batch.chainids); new_lengths]
    new_batch = dummy_batch(all_lengths)
    # Existing residue indices are kept; every added chain is numbered from 1.
    new_batch.resinds .= [batch.resinds; [1:l for l in new_lengths]...]
    new_X₀ = zero_state(new_batch)
    added = sum(new_lengths)
    tensor(new_X₀[1])[:, :, 1:end-added, :] .= tensor(X₀[1])
    tensor(new_X₀[2])[:, :, 1:end-added, :] .= tensor(X₀[2])
    tensor(new_X₀[3]).indices[1:end-added, :] .= tensor(X₀[3]).indices
    println("Added $(new_lengths) residues")
    return new_X₀, new_batch
end

# One stage of a `flex_quickgen` run.
@kwdef struct Phase
# `start => stop` time interval; `flex_quickgen` treats start > stop as a reverse pass.
interval::Pair{Float32,Float32}
# Spacing of the time grid built between the two interval ends.
step_size::Float32 = 0.005
# Lengths of chains appended (via `add_residues`) before the phase runs; `nothing` adds none.
new_lengths::Union{Nothing,Vector{Int}} = nothing
# Forwarded as `recdim` on reverse passes.
record_dim::Union{Nothing,Int} = nothing
# On forward passes: replay the record produced by the most recent reverse pass.
use_record::Bool = false
# Forwarded together with the record when `use_record` is set.
snap_time::Float32 = 0.9
# Passed as the `tracker` keyword of the pass and returned alongside its result.
tracker = Tracker()
end

# Convenience constructor: give the interval positionally, everything else by keyword.
Phase(interval::Pair; kws...) = Phase(; interval, kws...)

"""
    flex_quickgen(P, batch, X₀, model; phases = [Phase(0.0 => 1.0)], kws...)

Run a sequence of `Phase`s, feeding the final state of each phase into the next,
and return a vector with one `(final_state, batch, tracker)` tuple per phase.

For every phase:

- chains listed in `phase.new_lengths` are appended to the state and batch first;
- if the interval runs backwards (start > stop), a reverse pass is run under
  `reverse_process(P)` and the record it produces is kept;
- otherwise a forward pass is run under `P`; with `phase.use_record` the kept
  record and `phase.snap_time` are forwarded to it.

Extra keywords `kws` are forwarded to every `flow_quickgen` call.

Throws an `ArgumentError` if a phase sets `use_record` before any reverse phase
has produced a record.
"""
function flex_quickgen(
    P, batch, X₀, model;
    phases = [Phase(0.0 => 1.0)],
    kws...
)
    X₀ᵢ = X₀
    output = []
    # Initialised explicitly so a misordered phase list can be reported clearly
    # instead of surfacing as an UndefVarError.
    record = nothing
    for phase in phases
        start, stop = phase.interval
        is_reverse = start > stop
        a, b = min(start, stop), max(start, stop)
        steps = range(Float32(a), Float32(b), step=phase.step_size)
        X₀ᵢ, batch = add_residues(X₀ᵢ, batch, phase.new_lengths)
        if !is_reverse && phase.use_record && isnothing(record)
            throw(ArgumentError(
                "phase $(phase.interval) has use_record=true but no earlier reverse phase produced a record"))
        end
        @time if is_reverse
            record, X₁ᵢ = flow_quickgen(
                reverse_process(P), batch, X₀ᵢ, model;
                is_reverse, steps,
                recdim = phase.record_dim,
                phase.tracker, kws...)
        else
            # NOTE(review): `snap_time` ends up as a keyword of the binding sampler;
            # confirm that sampler accepts it.
            bind_kws = phase.use_record ? (; record, phase.snap_time) : (;)
            X₁ᵢ = flow_quickgen(
                P, batch, X₀ᵢ, model;
                steps, phase.tracker,
                bind_kws..., kws...)
        end
        push!(output, (X₁ᵢ, batch, phase.tracker))
        X₀ᵢ = X₁ᵢ
    end
    return output
end

export flex_quickgen, Phase, compound_state
72 changes: 72 additions & 0 deletions src/reverse.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using Flowfusion:
resolveprediction, mask, step,
tensor, UProcess, UState, unhot,
selectlastdim, lastsize

"""
    reverse_gen(P, X₀, model, steps, record; recdim = nothing,
                tracker = Returns(nothing), midpoint = false, snap_time = 0)

Step the state `X₀` through consecutive pairs of `steps` under the process tuple
`P`, pushing one entry per step onto `record`, and return the final state.

`model(t, Xₜ)` must return a pair `(X̂₀, X̂₁)`; `X̂₀` drives the step, `X̂₁` is what
gets recorded and tracked. `record` first receives `(1, X₀, nothing)` and then,
per step, `(1 - s₂, Xₜ, X̂₁)` with both states cut to the first `recdim` positions
(all positions when `recdim` is `nothing`). While `t < snap_time`, the first two
components of the recorded `X̂₁` are replaced by those of `X₀`.

`tracker` is called as `tracker(1 - t, Xₜ, X̂₁)` after every step.
"""
function reverse_gen(
    P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}},
    model, steps::AbstractVector, record;
    recdim = nothing,
    tracker = Returns(nothing),
    midpoint = false,
    snap_time = 0,
)
    if isnothing(recdim)
        recdim = lastsize(X₀, offset=1)
    end
    kept = 1:recdim
    Xₜ = copy.(X₀)
    push!(record, (1, X₀, nothing))
    for (from, s₂) in zip(steps, steps[begin+1:end])
        T = eltype(from)
        # A repeated time point would give a zero-length step; nudge the start back.
        s₁ = from == s₂ ? from - T(0.001) : from
        t = midpoint ? (s₁ + s₂) / 2 : s₁
        X̂₀, X̂₁ = model(t, Xₜ)
        X̂₀ = resolveprediction(X̂₀, Xₜ)
        X̂₁ = resolveprediction(X̂₁, Xₜ)
        Xₜ = mask(step(P, Xₜ, X̂₀, s₁, s₂), X₀)
        # Record a copy of the prediction, snapped onto X₀ while t < snap_time.
        recorded_X̂₁ = deepcopy(X̂₁)
        if t < snap_time
            tensor(recorded_X̂₁[1]) .= tensor(X₀[1])
            tensor(recorded_X̂₁[2]) .= tensor(X₀[2])
        end
        push!(record, (
            1-s₂,
            selectlastdim(deepcopy(Xₜ), kept, offset=1),
            selectlastdim(recorded_X̂₁, kept, offset=1),
        ))
        tracker(1-t, Xₜ, X̂₁)
    end
    return Xₜ
end

# Overwrite the leading positions of `Xₜ` (as many as `old_Xₜ` holds) with the
# recorded state `old_Xₜ`, in place.
function _impose_recorded!(Xₜ, old_Xₜ)
    old_size = size(tensor(old_Xₜ[1]), 3)
    tensor(Xₜ[1])[:, :, 1:old_size, :] .= tensor(old_Xₜ[1])
    tensor(Xₜ[2])[:, :, 1:old_size, :] .= tensor(old_Xₜ[2])
    tensor(Xₜ[3]).indices[1:old_size, :] .= tensor(old_Xₜ[3]).indices
    return Xₜ
end

"""
    bind_gen(P, X₀, model, steps, record; tracker = Returns(nothing),
             midpoint = false, snap_time = 0)

Step the state `X₀` forward through consecutive pairs of `steps` under the process
tuple `P` while replaying `record`, and return the final state.

Each step removes the last entry `(time, Xₜ, _)` from `record`, overwrites the
leading positions of the current state with the recorded `Xₜ`, and then steps
from the recorded time to `s₂`. When a single entry is left in `record` after
that removal, the recorded state is imposed once more on the stepped state.
`record` is consumed (mutated) by this function.

`snap_time` is accepted for signature parity with `reverse_gen`, because callers
forward the same keyword set to both samplers; it is not used here.

Throws an `ArgumentError` if `record` runs out before `steps` does.
"""
function bind_gen(
    P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}},
    model, steps::AbstractVector, record;
    tracker = Returns(nothing),
    midpoint = false,
    snap_time = 0,
)
    Xₜ = copy.(X₀)
    for (s₁, s₂) in zip(steps, steps[begin+1:end])
        if isempty(record)
            throw(ArgumentError("bind_gen: `record` has fewer entries than there are steps"))
        end
        t = midpoint ? (s₁ + s₂) / 2 : s₁
        X̂₁ = resolveprediction(model(t, Xₜ), Xₜ)
        # From here on the step starts at the recorded time, not the grid time.
        s₁, old_Xₜ, _ = pop!(record)
        _impose_recorded!(Xₜ, old_Xₜ)
        Xₜ = mask(step(P, Xₜ, X̂₁, s₁, s₂), X₀)
        if length(record) == 1
            _impose_recorded!(Xₜ, old_Xₜ)
        end
        tracker(t, Xₜ, X̂₁)
    end
    return Xₜ
end