Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
.eggs/
.vscode/
*.swp
.coverage
coverage.xml
7 changes: 7 additions & 0 deletions careless/args/interpretation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
"default" : None,
}),

# Unit cell override applied to every input data set. The raw string is
# split on "," and each piece converted to float before being passed
# positionally to gemmi.UnitCell by the formatter's `from_parser`.
(("--unitcell", ), {
"help":"The unit cell parameters to use for merging. Currently only support a single unit "
"cell for all.",
"type":str,
"default" : None,
}),

(("--image-key", ), {
"help":"The name of the key indicating image number for each data set. "
"If no key is given, careless will use the first key with the BATCH dtype.",
Expand Down
19 changes: 16 additions & 3 deletions careless/io/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class DataFormatter():
formatted model inputs
"""
spacegroups = None
cell = None
def pack_inputs(self, inputs_dict):
"""
inputs_dict : {k:v} where k corresponds to one of careless.models.base.BaseModel.input_index.keys()
Expand All @@ -53,7 +54,7 @@ def pack_inputs(self, inputs_dict):
break
return inputs

def prep_dataset(self, ds : rs.DataSet, spacegroup : Optional[gemmi.SpaceGroup] = None) -> rs.DataSet:
def prep_dataset(self, ds : rs.DataSet, spacegroup : Optional[gemmi.SpaceGroup] = None, cell : Optional[gemmi.UnitCell] = None) -> rs.DataSet:
raise NotImplementedError("Formatter classes should implement `prep_dataset`")

def finalize(self, data : rs.DataSet, rac : ReciprocalASUCollection) -> (tuple, ReciprocalASUCollection):
Expand All @@ -80,7 +81,7 @@ def get_data_and_asu_collection(self, datasets):
sg = None
if self.spacegroups is not None:
sg = self.spacegroups[file_id]
ds = self.prep_dataset(ds, sg)
ds = self.prep_dataset(ds, sg, self.cell)

if self.separate_outputs:
asu_id = file_id
Expand Down Expand Up @@ -159,6 +160,7 @@ def __init__(
positional_encoding_keys=None,
encoding_bit_depth=5,
spacegroups = None,
cell = None,
standardize = True,
):
"""
Expand Down Expand Up @@ -194,6 +196,7 @@ def __init__(
self.positional_encoding_keys = positional_encoding_keys
self.ecoding_bit_depth = encoding_bit_depth
self.spacegroups = spacegroups
self.cell = cell
self.standardize = standardize

@classmethod
Expand All @@ -213,6 +216,11 @@ def from_parser(cls, parser):
"Multiple values provided for --spacegroups=, but the number of provided values does not match the number of reflection files. "
"Either provide a single spacegroup or one per reflection file as a comma-separated list. "
)

cell = None
if parser.unitcell is not None:
cell = [float(i) for i in parser.unitcell.split(",")]
cell = gemmi.UnitCell(*cell)

return cls(
parser.intensity_key,
Expand All @@ -226,10 +234,11 @@ def from_parser(cls, parser):
pe_keys,
parser.positional_encoding_frequencies,
spacegroups,
cell,
standardize=parser.standardize_metadata,
)

def prep_dataset(self, ds, spacegroup=None, inplace=True):
def prep_dataset(self, ds, spacegroup=None, cell=None, inplace=True):
"""
Format a single data set.
- Apply resolution cutoff (dHKL >= dmin)
Expand All @@ -245,6 +254,8 @@ def prep_dataset(self, ds, spacegroup=None, inplace=True):
The rs DataSet instance to be standardized
spacegroup : gemmi.SpaceGroup
Optionally override ds.spacegroup with this object.
cell : gemmi.UnitCell
Optionally override ds.cell with this object.
inplace : bool (optional)
By default this method operates in place on the passed dataset.
Set this parameter to False in order to operate on a copy.
Expand All @@ -258,6 +269,8 @@ def prep_dataset(self, ds, spacegroup=None, inplace=True):

if spacegroup is not None:
ds.spacegroup = spacegroup
if cell is not None:
ds.cell = cell

# Avoid non-unique MultiIndex complications
ds.reset_index(inplace=True)
Expand Down