Skip to content

Commit aa87afc

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent bf0e41e commit aa87afc

14 files changed

+100
-62
lines changed

dpgen2/entrypoint/args.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,11 @@ def lmp_args():
261261
doc=doc_filters,
262262
),
263263
Argument(
264-
"lammps_input_file", str, optional=True, default=None, doc=doc_lammps_input_file
264+
"lammps_input_file",
265+
str,
266+
optional=True,
267+
default=None,
268+
doc=doc_lammps_input_file,
265269
),
266270
]
267271

dpgen2/entrypoint/submit.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@
8181
CustomizedLmpTemplateTaskGroup,
8282
ExplorationStage,
8383
ExplorationTask,
84-
LmpTemplateTaskGroup,
8584
LmpSpinTaskGroup,
85+
LmpTemplateTaskGroup,
8686
NPTTaskGroup,
8787
caly_normalize,
8888
diffcsp_normalize,
@@ -319,7 +319,9 @@ def make_naive_exploration_scheduler_without_conf(config, explore_style):
319319
conv_style = convergence.pop("type")
320320
report = conv_styles[conv_style](**convergence)
321321
# trajectory render, the format of the output trajs are assumed to be lammps/dump
322-
render = TrajRenderLammps(nopbc=output_nopbc,lammps_input_file=config["explore"]["lammps_input_file"])
322+
render = TrajRenderLammps(
323+
nopbc=output_nopbc, lammps_input_file=config["explore"]["lammps_input_file"]
324+
)
323325
# selector
324326
selector = ConfSelectorFrames(
325327
render,
@@ -379,7 +381,11 @@ def make_lmp_naive_exploration_scheduler(config):
379381
# report
380382
conv_style = convergence.pop("type")
381383
report = conv_styles[conv_style](**convergence)
382-
render = TrajRenderLammps(nopbc=output_nopbc, use_ele_temp=use_ele_temp, lammps_input_file=config["explore"]["lammps_input_file"])
384+
render = TrajRenderLammps(
385+
nopbc=output_nopbc,
386+
use_ele_temp=use_ele_temp,
387+
lammps_input_file=config["explore"]["lammps_input_file"],
388+
)
383389
# selector
384390
selector = ConfSelectorFrames(
385391
render,

dpgen2/exploration/render/traj_render_lammps.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(
4242
self,
4343
nopbc: bool = False,
4444
use_ele_temp: int = 0,
45-
lammps_input_file: str = None, # type: ignore
45+
lammps_input_file: str = None, # type: ignore
4646
):
4747
self.nopbc = nopbc
4848
self.use_ele_temp = use_ele_temp
@@ -80,10 +80,10 @@ def _load_one_model_devi(self, fname, model_devi):
8080
model_devi.add(DeviManager.MIN_DEVI_F, dd[:, 5]) # type: ignore
8181
model_devi.add(DeviManager.AVG_DEVI_F, dd[:, 6]) # type: ignore
8282
# assume the 7-9 columns are for MF
83-
if dd.shape[1] >= 10: # type: ignore
84-
model_devi.add(DeviManager.MAX_DEVI_MF, dd[:, 7]) # type: ignore
85-
model_devi.add(DeviManager.MIN_DEVI_MF, dd[:, 8]) # type: ignore
86-
model_devi.add(DeviManager.AVG_DEVI_MF, dd[:, 9]) # type: ignore
83+
if dd.shape[1] >= 10: # type: ignore
84+
model_devi.add(DeviManager.MAX_DEVI_MF, dd[:, 7]) # type: ignore
85+
model_devi.add(DeviManager.MIN_DEVI_MF, dd[:, 8]) # type: ignore
86+
model_devi.add(DeviManager.AVG_DEVI_MF, dd[:, 9]) # type: ignore
8787

8888
def get_ele_temp(self, optional_outputs):
8989
ele_temp = []
@@ -139,7 +139,9 @@ def get_confs(
139139
else:
140140
traj = trajs[ii]
141141
# for spin job, need to read input file to get the key of the spin data
142-
ss = dpdata.System(traj, fmt=traj_fmt, type_map=type_map, input_file=lammps_input_file)
142+
ss = dpdata.System(
143+
traj, fmt=traj_fmt, type_map=type_map, input_file=lammps_input_file
144+
)
143145
ss.nopbc = self.nopbc
144146
if ele_temp:
145147
self.set_ele_temp(ss, ele_temp[ii])

dpgen2/exploration/report/report_adaptive_lower.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,12 @@ def make_class_doc_link(key):
175175
rate_candi_v_link = make_class_doc_link("rate_candi_v")
176176
numb_candi_mf_link = make_class_doc_link("numb_candi_mf")
177177
rate_candi_mf_link = make_class_doc_link("rate_candi_mf")
178-
numb_candi_s = f"{numb_candi_f_link} or {numb_candi_v_link} or {numb_candi_mf_link}"
179-
rate_candi_s = f"{rate_candi_f_link} or {rate_candi_v_link} or {rate_candi_mf_link}"
178+
numb_candi_s = (
179+
f"{numb_candi_f_link} or {numb_candi_v_link} or {numb_candi_mf_link}"
180+
)
181+
rate_candi_s = (
182+
f"{rate_candi_f_link} or {rate_candi_v_link} or {rate_candi_mf_link}"
183+
)
180184
level_f_hi_link = make_class_doc_link("level_f_hi")
181185
level_v_hi_link = make_class_doc_link("level_v_hi")
182186
level_mf_hi_link = make_class_doc_link("level_mf_hi")
@@ -232,7 +236,11 @@ def args() -> List[Argument]:
232236
"numb_candi_mf", int, optional=True, default=0, doc=doc_numb_candi_mf
233237
),
234238
Argument(
235-
"rate_candi_mf", float, optional=True, default=0.0, doc=doc_rate_candi_mf
239+
"rate_candi_mf",
240+
float,
241+
optional=True,
242+
default=0.0,
243+
doc=doc_rate_candi_mf,
236244
),
237245
Argument(
238246
"n_checked_steps", int, optional=True, default=2, doc=doc_n_check_steps
@@ -286,9 +294,14 @@ def record(
286294
coll_mf = []
287295
# loop over trajs
288296
for ii in range(ntraj):
289-
add_nframes, add_accur, add_failed, add_f, add_v, add_mf = self._record_one_traj(
290-
ii, md_f[ii], md_v[ii], md_mf[ii]
291-
)
297+
(
298+
add_nframes,
299+
add_accur,
300+
add_failed,
301+
add_f,
302+
add_v,
303+
add_mf,
304+
) = self._record_one_traj(ii, md_f[ii], md_v[ii], md_mf[ii])
292305
self.nframes += add_nframes
293306
self.accur.update(add_accur)
294307
self.failed += add_failed
@@ -319,14 +332,14 @@ def record(
319332
self.level_v_lo = coll_v[-numb_candi_v][0]
320333
if not self.has_virial:
321334
self.level_v_lo = None
322-
335+
323336
if numb_candi_mf == 0:
324337
self.level_mf_lo = self.level_mf_hi
325338
else:
326339
self.level_mf_lo = coll_mf[-numb_candi_mf][0]
327340
if not self.has_mf:
328341
self.level_mf_lo = None
329-
342+
330343
if numb_candi_f == 0:
331344
self.level_f_lo = self.level_f_hi
332345
else:
@@ -385,7 +398,11 @@ def _record_one_traj(
385398
coll_v = []
386399
coll_mf = []
387400
for ii in range(nframes):
388-
if md_f[ii] > self.level_f_hi or md_v[ii] > self.level_v_hi or md_mf[ii] > self.level_mf_hi:
401+
if (
402+
md_f[ii] > self.level_f_hi
403+
or md_v[ii] > self.level_v_hi
404+
or md_mf[ii] > self.level_mf_hi
405+
):
389406
failed.append((tt, ii))
390407
else:
391408
coll_f.append([md_f[ii], tt, ii])

dpgen2/exploration/report/report_trust_levels_base.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def __init__(
4444
self.conv_accuracy = conv_accuracy
4545
self.clear()
4646
self.v_level = (self.level_v_lo is not None) and (self.level_v_hi is not None)
47-
self.mf_level = (self.level_mf_lo is not None) and (self.level_mf_hi is not None)
47+
self.mf_level = (self.level_mf_lo is not None) and (
48+
self.level_mf_hi is not None
49+
)
4850
self.model_devi = None
4951

5052
print_tuple = (
@@ -242,16 +244,18 @@ def _record_one_traj(
242244
set_mf_accu = set_full if nomagforce else set(id_mf_accu)
243245
set_mf_cand = set([]) if nomagforce else set(id_mf_cand)
244246
set_mf_fail = set([]) if nomagforce else set(id_mf_fail)
245-
247+
246248
# check consistency
247249
assert set_full == set_f_accu | set_f_cand | set_f_fail
248-
for accu, cand, fail in [[set_f_accu, set_f_cand, set_f_fail],
249-
[set_v_accu, set_v_cand, set_v_fail],
250-
[set_mf_accu, set_mf_cand, set_mf_fail]]:
250+
for accu, cand, fail in [
251+
[set_f_accu, set_f_cand, set_f_fail],
252+
[set_v_accu, set_v_cand, set_v_fail],
253+
[set_mf_accu, set_mf_cand, set_mf_fail],
254+
]:
251255
assert 0 == len(accu & cand)
252256
assert 0 == len(accu & fail)
253257
assert 0 == len(cand & fail)
254-
258+
255259
# accu, cand, fail
256260
set_accu = set_f_accu & set_v_accu & set_mf_accu
257261
set_fail = set_f_fail | set_v_fail | set_mf_fail

dpgen2/exploration/task/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
from .diffcsp_task_group import (
1111
DiffCSPTaskGroup,
1212
)
13-
from .lmp_template_task_group import (
14-
LmpTemplateTaskGroup,
15-
)
1613
from .lmp_spin_task_group import (
1714
LmpSpinTaskGroup,
1815
)
16+
from .lmp_template_task_group import (
17+
LmpTemplateTaskGroup,
18+
)
1919
from .make_task_group_from_config import (
2020
caly_normalize,
2121
caly_task_group_args,

dpgen2/exploration/task/lmp_spin_task_group.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def set_lmp(
4141
numb_models: int,
4242
lmp_template_fname: str,
4343
plm_template_fname: Optional[str] = None,
44-
revisions: dict = {}
44+
revisions: dict = {},
4545
) -> None:
4646
self.lmp_template = Path(lmp_template_fname).read_text().split("\n")
4747
self.revisions = revisions
@@ -104,6 +104,7 @@ def _make_lmp_task(
104104
)
105105
return task
106106

107+
107108
def revise_by_keys(lmp_lines, keys, values):
108109
for kk, vv in zip(keys, values): # type: ignore
109110
for ii in range(len(lmp_lines)):

dpgen2/exploration/task/make_task_group_from_config.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,12 @@
1919
from dpgen2.exploration.task.customized_lmp_template_task_group import (
2020
CustomizedLmpTemplateTaskGroup,
2121
)
22-
from dpgen2.exploration.task.lmp_template_task_group import (
23-
LmpTemplateTaskGroup,
24-
)
2522
from dpgen2.exploration.task.lmp_spin_task_group import (
2623
LmpSpinTaskGroup,
2724
)
28-
25+
from dpgen2.exploration.task.lmp_template_task_group import (
26+
LmpTemplateTaskGroup,
27+
)
2928
from dpgen2.exploration.task.npt_task_group import (
3029
NPTTaskGroup,
3130
)
@@ -180,6 +179,7 @@ def lmp_template_task_group_args():
180179
),
181180
]
182181

182+
183183
def lmp_spin_task_group_args():
184184
doc_lmp_template_fname = "The file name of lammps input template"
185185
doc_plm_template_fname = "The file name of plumed input template"
@@ -210,11 +210,12 @@ def lmp_spin_task_group_args():
210210
alias=["plm_template", "plm"],
211211
),
212212
Argument(
213-
"revisions",
214-
dict,
215-
optional=True,
216-
default={},
217-
doc=doc_revisions,)
213+
"revisions",
214+
dict,
215+
optional=True,
216+
default={},
217+
doc=doc_revisions,
218+
),
218219
]
219220

220221

dpgen2/op/run_dp_train.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -522,9 +522,7 @@ def training_args():
522522
doc_init_model_start_pref_fm = (
523523
"The start magnetic force prefactor in loss when init-model for spin job"
524524
)
525-
doc_spin = (
526-
"If is a spin job"
527-
)
525+
doc_spin = "If is a spin job"
528526
doc_init_model_start_pref_v = (
529527
"The start virial prefactor in loss when init-model"
530528
)

tests/exploration/test_devi_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_devi_manager_std_check_data(self):
8484
model_devi.get,
8585
DeviManager.MAX_DEVI_V,
8686
)
87-
87+
8888
self.assertRaisesRegex(
8989
AssertionError,
9090
"Error: the number of model deviation",
@@ -119,7 +119,7 @@ def test_devi_manager_std_check_data(self):
119119
model_devi.get,
120120
DeviManager.MAX_DEVI_V,
121121
)
122-
122+
123123
model_devi = DeviManagerStd()
124124
model_devi.add(DeviManager.MAX_DEVI_F, np.array([1, 2, 3]))
125125
model_devi.add(DeviManager.MAX_DEVI_F, np.array([4, 5, 6]))

0 commit comments

Comments
 (0)