Skip to content

Commit cfdd91d

Browse files
Merge pull request #70 from TheDisorderedOrganization/fix_io
Load multiple chains
2 parents 86bf131 + 5fbdad6 commit cfdd91d

4 files changed

Lines changed: 105 additions & 78 deletions

File tree

src/IO/IO.jl

Lines changed: 68 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ function load_configuration(io, format::Arianna.Format; m=1)
6868
end
6969
if "pos" in keys(column_info)
7070
pos_d, pos_index = column_info["pos"]
71-
position = Vector{SVector{pos_d, Float64}}(undef, N)
71+
position = Vector{SVector{pos_d,Float64}}(undef, N)
7272
else
7373
missing_key_error("pos")
7474
end
@@ -85,12 +85,12 @@ function load_configuration(io, format::Arianna.Format; m=1)
8585
end
8686
position[i] = SVector{pos_d}(parse.(Float64, split_line[pos_index:pos_index+pos_d-1]))
8787
end
88-
config_dict = Dict( :N => N,
89-
:d => pos_d,
90-
:box => box,
91-
:species => species,
92-
:position => position,
93-
:metadata => metadata
88+
config_dict = Dict(:N => N,
89+
:d => pos_d,
90+
:box => box,
91+
:species => species,
92+
:position => position,
93+
:metadata => metadata
9494
)
9595
if bool_molecule
9696
config_dict[:molecule] = molecule
@@ -133,22 +133,22 @@ function get_model(data, i::Int, j::Int)
133133
rcut = get(m, "rcut", nothing)
134134

135135
return GeneralKG(m["epsilon"], m["sigma"], m["k"], m["r0"];
136-
filter_kwargs(
137-
:rcut => get(m, "rcut", nothing),
138-
:ϵbond => get(m, "epsilonbond", nothing),
139-
:σbond => get(m, "sigmabond", nothing),
140-
:rcutbond => get(m, "rcutbond", nothing),
141-
)...)
136+
filter_kwargs(
137+
:rcut => get(m, "rcut", nothing),
138+
:ϵbond => get(m, "epsilonbond", nothing),
139+
:σbond => get(m, "sigmabond", nothing),
140+
:rcutbond => get(m, "rcutbond", nothing),
141+
)...)
142142
elseif m["name"] == "SmoothLennardJones"
143143
return SmoothLennardJones(m["epsilon"], m["sigma"];
144-
filter_kwargs(
145-
:rcut => get(m, "rcut", nothing))...)
144+
filter_kwargs(
145+
:rcut => get(m, "rcut", nothing))...)
146146
elseif m["name"] == "LennardJones"
147147
return LennardJones(m["epsilon"], m["sigma"];
148-
filter_kwargs(
149-
:rcut => get(m, "rcut", nothing),
150-
:shift_potential => get(m, "shift_potential", true),
151-
)...)
148+
filter_kwargs(
149+
:rcut => get(m, "rcut", nothing),
150+
:shift_potential => get(m, "shift_potential", true),
151+
)...)
152152
else
153153
error("Model $(m["name"]) is not implemented")
154154
return nothing
@@ -181,11 +181,11 @@ function read_bonds(data, N, format::Arianna.Format)
181181
row_bonds = get_row_bonds(selrow, N, format)
182182
bond = [Vector{Int}() for _ in 1:N]
183183
for i in 1:N_bonds
184-
atom_i, atom_j = parse.(Int, split(bonds_data[row_bonds + i], " ")[bond_index:bond_index+1])
184+
atom_i, atom_j = parse.(Int, split(bonds_data[row_bonds+i], " ")[bond_index:bond_index+1])
185185
push!(bond[atom_i], atom_j)
186186
push!(bond[atom_j], atom_i)
187187
if bool_btype
188-
btype_ij = parse.(Int, split(bonds_data[row_bonds + i], " ")[btype_index])
188+
btype_ij = parse.(Int, split(bonds_data[row_bonds+i], " ")[btype_index])
189189
else
190190
btype_ij = 1
191191
end
@@ -207,57 +207,81 @@ function broadcast_dict(dicts, key)
207207
return [dict[key] for dict in dicts]
208208
end
209209

210-
function load_chains(init_path; args=Dict(), verbose=false)
210+
function load_chains(init_path; args=Dict(), filename="", verbose=false)
211211
input_files = Vector{String}()
212212
if isfile(init_path)
213213
push!(input_files, init_path)
214214
elseif isdir(init_path)
215215
for (root, dirs, files) in walkdir(init_path)
216216
for file in files
217-
push!(input_files, joinpath(root, file))
217+
if occursin(filename, file)
218+
push!(input_files, joinpath(root, file))
219+
end
218220
end
219221
end
220222
end
221223
verbose && println("Processing $(length(input_files)) configuration file(s)")
222224
verbose && @show input_files
225+
223226
config_dict = load_configuration.(input_files)
224227
initial_species_array = broadcast_dict(config_dict, :species)
225-
initial_position_array = broadcast_dict(config_dict, :position)
226-
initial_box_array = broadcast_dict(config_dict, :box)
228+
initial_position_array = broadcast_dict(config_dict, :position)
229+
initial_box_array = broadcast_dict(config_dict, :box)
227230
metadata_array = broadcast_dict(config_dict, :metadata)
231+
228232
N, d = config_dict[1][:N], config_dict[1][:d]
229233
@assert all(isequal(N), length.(initial_position_array))
230234
@assert all(isequal(d), vcat([length.(X) for X in initial_position_array]...))
235+
231236
initial_density_array = length.(initial_position_array) ./ prod.(initial_box_array)
232-
if length(metadata_array) > 1
233-
initial_temperature_array = [parse(Float64, split(filter(x -> occursin("T:", x), metadata)[1], ":")[2]) for metadata in metadata_array]
234-
input_models = [split(filter(x -> occursin("model:", x), metadata)[1], ":")[2] for metadata in metadata_array]
235-
@assert all(isequal(input_models[1]), input_models)
237+
238+
has_temp = all(m -> any(x -> occursin("T:", x), m), metadata_array)
239+
has_model = all(m -> any(x -> occursin("model:", x), m), metadata_array)
240+
241+
if length(metadata_array) >= 1 && has_temp
242+
initial_temperature_array = [parse(Float64, split(filter(x -> occursin("T:", x), m)[1], ":")[2]) for m in metadata_array]
236243
else
237244
initial_temperature_array = nothing
245+
end
246+
247+
if length(metadata_array) >= 1 && has_model
248+
input_models = [split(filter(x -> occursin("model:", x), m)[1], ":")[2] for m in metadata_array]
249+
@assert all(isequal(input_models[1]), input_models)
250+
else
238251
input_models = nothing
239252
end
240-
# Update density, temperature and model if needed
253+
254+
# Update density if needed
241255
if haskey(args, "density") && !isnothing(args["density"])
242256
λs = (initial_density_array ./ args["density"]) .^ (1 / d)
243257
initial_density_array .= args["density"]
244258
initial_position_array .= [X .* λ for (X, λ) in zip(initial_position_array, λs)]
245259
initial_box_array .= [box .* λ for (box, λ) in zip(initial_box_array, λs)]
246260
end
261+
262+
# Safely overriding arrays without type or dimension conflicts
247263
if haskey(args, "temperature") && !isnothing(args["temperature"])
248-
initial_temperature_array = args["temperature"]
264+
if args["temperature"] isa AbstractVector
265+
initial_temperature_array = args["temperature"]
266+
else
267+
initial_temperature_array = fill(args["temperature"], length(input_files))
268+
end
249269
elseif isnothing(initial_temperature_array)
250270
missing_key_error("temperature")
251271
end
252-
if haskey(args, "model") && !isnothing(args["model"])
253-
input_models = args["model"]
254272

273+
if haskey(args, "model") && !isnothing(args["model"])
274+
if args["model"] isa AbstractVector
275+
input_models = args["model"]
276+
else
277+
input_models = fill(args["model"], length(input_files))
278+
end
255279
elseif isnothing(input_models)
256280
missing_key_error("model")
257281
end
282+
258283
# Fold back into the box
259284
initial_position_array .= [[fold_back(x, box) for x in X] for (X, box) in zip(initial_position_array, initial_box_array)]
260-
# Parse model
261285

262286
# Copy configurations nsim times (replicas)
263287
if haskey(args, "nsim") && !isnothing(args["nsim"]) && args["nsim"] > 1
@@ -268,40 +292,43 @@ function load_chains(init_path; args=Dict(), verbose=false)
268292
initial_density_array = vcat([[copy(x) for _ in 1:nsim] for x in initial_density_array]...)
269293
initial_temperature_array = vcat([[copy(x) for _ in 1:nsim] for x in initial_temperature_array]...)
270294
end
271-
# Handle cell list (this is classy)
295+
296+
# Parse model
272297
available_species = unique(vcat(initial_species_array...))
273298
n_species = length(available_species)
274299
if input_models[1] isa Dict
275-
model_matrix = SMatrix{n_species, n_species}([get_model(input_models[1], i, j) for i in 1:n_species, j in 1:n_species])
300+
model_matrix = SMatrix{n_species,n_species}([get_model(input_models[1], i, j) for i in 1:n_species, j in 1:n_species])
276301
elseif occursin(r"\(", input_models[1]) && occursin(r"\)", input_models[1])
277-
model_matrix = eval(Meta.parse(input_models[1])) # Parse the string if it has parentheses
302+
model_matrix = eval(Meta.parse(input_models[1]))
278303
else
279-
model_matrix = eval(Meta.parse(input_models[1] * "()")) # Else, append () and evaluate
304+
model_matrix = eval(Meta.parse(input_models[1] * "()"))
280305
end
281306
@assert isa(model_matrix, AbstractArray)
282307

283308
maxcut = maximum([m.rcut for m in model_matrix])
284309
Z = mean(initial_density_array) * volume_sphere(maxcut, d)
285310
list_type = Z / N < 0.1 ? LinkedList : EmptyList
311+
286312
if haskey(args, "list_type") && !isnothing(args["list_type"])
287313
list_type = eval(Meta.parse(args["list_type"]))
288314
end
315+
289316
list_parameters = get(args, "list_parameters", nothing)
290317
verbose && println("Using $list_type as cell list type")
291-
# Create independent chains
318+
319+
# Create independent chains (Preserves V2 System constraints)
292320
bool_molecule = :molecule in keys(config_dict[1])
293321
if bool_molecule
294322
initial_molecule_array = broadcast_dict(config_dict, :molecule)
295323
initial_bond_array = broadcast_dict(config_dict, :bond)
296-
#initial_btype_array = broadcast_dict(config_dict, :btype)
297324
chains = [System(initial_position_array[k], initial_species_array[k], initial_molecule_array[k], initial_density_array[k], initial_temperature_array[k], model_matrix, initial_bond_array[k], list_type=list_type, list_parameters=list_parameters) for k in eachindex(initial_position_array)]
298325
else
299326
chains = [System(initial_position_array[k], initial_species_array[k], initial_density_array[k], initial_temperature_array[k], model_matrix, list_type=list_type, list_parameters=list_parameters) for k in eachindex(initial_position_array)]
300327
end
328+
301329
verbose && println("$(length(chains)) chains created")
302330
return chains
303331
end
304-
305332
function formatted_string(num::Real, digits::Integer)
306333
fmtstr = "%." * string(digits) * "f"
307334
fmt = Printf.Format(fmtstr)

src/IO/exyz.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@ end
77

88
function parse_column_string(column_str::AbstractString, ::EXYZ)
99
columns = split(column_str, ":")
10-
column_info = OrderedDict{String, Vector}() # Use OrderedDict to maintain order
10+
column_info = OrderedDict{String,Vector}() # Use OrderedDict to maintain order
1111
i, index = 1, 1
1212
types = ["S", "I", "R"]
1313
while i <= length(columns)
1414
if i + 2 <= length(columns) && (columns[i+1] in types)
15-
column_name = columns[i]
16-
dimension = parse(Int, columns[i + 2])
17-
column_info[column_name] = [dimension, index]
18-
index += dimension
19-
i += 3 # Skip data type and dimension
15+
column_name = columns[i]
16+
dimension = parse(Int, columns[i+2])
17+
column_info[column_name] = [dimension, index]
18+
index += dimension
19+
i += 3 # Skip data type and dimension
2020
else
21-
i += 1
21+
i += 1
2222
end
2323
end
2424

@@ -43,8 +43,8 @@ function read_header(data, format::EXYZ)
4343
box = lattice_matrix[diagind(lattice_matrix)]
4444
column_match = match(r"Properties=(.*)", metadata_line)
4545
column_str = column_match === nothing ? nothing : column_match.captures[1]
46-
column_info = parse_column_string(column_str, format)
47-
return N, box, column_info, []
46+
column_info = parse_column_string(column_str, format)
47+
return N, box, column_info, split(metadata_line, " ")
4848
end
4949

5050
function get_selrow(::EXYZ, N, m)
@@ -79,7 +79,7 @@ function read_bonds_header(bonds, format::EXYZ)
7979
metadata_line = bonds[2]
8080
column_match = match(r"Properties=(.*)", metadata_line)
8181
column_str = column_match === nothing ? nothing : column_match.captures[1]
82-
column_info = parse_column_string(column_str, format)
82+
column_info = parse_column_string(column_str, format)
8383
return N_bonds, column_info
8484
end
8585

src/IO/xyz.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ end
1111

1212
function parse_column_string(column_str::AbstractString, ::XYZ; d::Int=3)
1313
columns = split(column_str, ",")
14-
column_info = OrderedDict{String, Vector}() # Use OrderedDict to maintain order
14+
column_info = OrderedDict{String,Vector}() # Use OrderedDict to maintain order
1515
index = 1
1616
for column_name in columns
1717
if column_name == "molecule"
@@ -21,10 +21,10 @@ function parse_column_string(column_str::AbstractString, ::XYZ; d::Int=3)
2121
dimension = 1
2222
column_info[column_name] = [dimension, index]
2323
elseif column_name == "position"
24-
column_info["pos"] = [d, index]
24+
column_info["pos"] = [d, index]
2525
elseif column_name == "bond"
2626
dimension = 2
27-
column_info["bond"] = [2, index]
27+
column_info["bond"] = [2, index]
2828
elseif column_name == "btype"
2929
dimension = 1
3030
column_info[column_name] = [dimension, index]
@@ -39,15 +39,15 @@ end
3939
function read_header(data, format::XYZ)
4040
N = parse(Int, data[1]) # Number of atoms or entries
4141
metadata = split(data[2], " ") # Metadata split into an array
42-
42+
4343
# Extract cell vector from metadata
4444
cell_str = replace(metadata[findfirst(startswith("cell:"), metadata)], "cell:" => "")
4545
cell_vector = parse.(Float64, split(cell_str, ","))
4646
d = length(cell_vector)
4747
box = SVector{d}(cell_vector)
4848
column_str = replace(metadata[findfirst(startswith("columns:"), metadata)], "columns:" => "")
4949
column_info = parse_column_string(column_str, format; d=d)
50-
return N, box, column_info, []
50+
return N, box, column_info, metadata
5151
end
5252

5353
function get_system_column(::Atoms, ::XYZ)

0 commit comments

Comments
 (0)