Commit 1da2891

Add flake8-comprehensions plugin

1 parent bc87813

27 files changed: +55 -60 lines

.pre-commit-config.yaml
Lines changed: 2 additions & 0 deletions

@@ -33,6 +33,8 @@ repos:
     rev: 6.0.0
     hooks:
       - id: flake8
+        additional_dependencies:
+          - flake8-comprehensions
   - repo: https://github.com/pycqa/isort
     rev: 5.12.0
     hooks:
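
Note: with flake8-comprehensions listed under additional_dependencies, pre-commit installs the plugin into the flake8 hook's isolated environment and flake8 discovers it automatically. The diffs below apply the rewrites the plugin suggests (its C400/C413/C414/C416 checks, going by the plugin's README). A minimal way to exercise the updated hook locally, using standard pre-commit commands:

    # Optional: drop cached hook environments (pre-commit also rebuilds them
    # on its own when additional_dependencies change).
    pre-commit clean
    # Run only the flake8 hook against every file in the repo.
    pre-commit run flake8 --all-files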

pytensor/compile/builders.py
Lines changed: 1 addition & 3 deletions

@@ -969,9 +969,7 @@ def inline_ofg_expansion(fgraph, node):
         return False
     if not op.is_inline:
         return False
-    return clone_replace(
-        op.inner_outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)}
-    )
+    return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))


 # We want to run this before the first merge optimizer
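
Note: the dropped comprehension only re-paired what zip already yields; dict() consumes an iterable of key/value pairs directly, so the two spellings build the same mapping (the plugin's C416 check, if I read its docs right). A throwaway sanity check with invented stand-ins for the real graph objects:

    inner_inputs = ["u0", "u1"]   # hypothetical stand-in for op.inner_inputs
    node_inputs = [10, 20]        # hypothetical stand-in for node.inputs
    # Both spellings produce {"u0": 10, "u1": 20}.
    assert {u: v for u, v in zip(inner_inputs, node_inputs)} == dict(
        zip(inner_inputs, node_inputs)
    )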

pytensor/gradient.py
Lines changed: 3 additions & 3 deletions

@@ -504,7 +504,7 @@ def grad(
     if not isinstance(wrt, Sequence):
         _wrt: List[Variable] = [wrt]
     else:
-        _wrt = [x for x in wrt]
+        _wrt = list(wrt)

     outputs = []
     if cost is not None:
@@ -791,8 +791,8 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):

     pgrads = dict(zip(params, grads))
     # separate wrt from end grads:
-    wrt_grads = list(pgrads[k] for k in wrt)
-    end_grads = list(pgrads[k] for k in end)
+    wrt_grads = [pgrads[k] for k in wrt]
+    end_grads = [pgrads[k] for k in end]

     if details:
         return wrt_grads, end_grads, start_grads, cost_grads
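
Note: both flavors of the same family appear here: a comprehension that merely copies becomes list() (C416), while a generator fed straight to list() becomes a list comprehension (C400), which also skips the intermediate generator frame. Sketched with throwaway values rather than PyTensor variables:

    wrt = ("a", "b")                 # hypothetical wrt sequence
    pgrads = {"a": 1.0, "b": 2.0}    # hypothetical gradient mapping
    assert [x for x in wrt] == list(wrt)
    assert list(pgrads[k] for k in wrt) == [pgrads[k] for k in wrt] == [1.0, 2.0]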

pytensor/graph/basic.py
Lines changed: 1 addition & 1 deletion

@@ -1629,7 +1629,7 @@ def as_string(
                 multi.add(op2)
             else:
                 seen.add(input.owner)
-    multi_list = [x for x in multi]
+    multi_list = list(multi)
     done: Set = set()

     def multi_index(x):

pytensor/graph/replace.py
Lines changed: 1 addition & 1 deletion

@@ -142,7 +142,7 @@ def toposort_key(fg: FunctionGraph, ts, pair):
             raise ValueError(f"{key} is not a part of graph")

     sorted_replacements = sorted(
-        tuple(fg_replace.items()),
+        fg_replace.items(),
         # sort based on the fg toposort, if a variable has no owner, it goes first
         key=partial(toposort_key, fg, toposort),
         reverse=True,
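
Note: sorted() accepts any iterable and copies it into a fresh list before sorting, so materializing the items view as a tuple first is a pure extra copy (the plugin's C414 check). Quick check with an invented mapping:

    fg_replace = {"x": "x2", "y": "y2"}   # hypothetical replacement mapping
    assert sorted(tuple(fg_replace.items())) == sorted(fg_replace.items())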

pytensor/graph/rewriting/basic.py
Lines changed: 2 additions & 2 deletions

@@ -2575,8 +2575,8 @@ def print_profile(cls, stream, prof, level=0):
         for i in range(len(loop_timing)):
             loop_times = ""
             if loop_process_count[i]:
-                d = list(
-                    reversed(sorted(loop_process_count[i].items(), key=lambda a: a[1]))
+                d = sorted(
+                    loop_process_count[i].items(), key=lambda a: a[1], reverse=True
                 )
                 loop_times = " ".join([str((str(k), v)) for k, v in d[:5]])
                 if len(d) > 5:
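
Note: folding the reversal into sorted() avoids an intermediate list and a second pass. One subtlety worth flagging: sorted() is stable, so with reverse=True, entries with equal counts keep their original relative order, whereas reversed(sorted(...)) flips them; for tie-free data the two agree exactly. Toy check with invented counts:

    counts = {"op_a": 3, "op_b": 7, "op_c": 1}   # hypothetical rewrite hit-counts
    old = list(reversed(sorted(counts.items(), key=lambda a: a[1])))
    new = sorted(counts.items(), key=lambda a: a[1], reverse=True)
    assert old == new == [("op_b", 7), ("op_a", 3), ("op_c", 1)]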

pytensor/link/c/basic.py
Lines changed: 2 additions & 2 deletions

@@ -633,11 +633,11 @@ def fetch_variables(self):

         # The orphans field is listified to ensure a consistent order.
         # list(fgraph.orphans.difference(self.outputs))
-        self.orphans = list(
+        self.orphans = [
             r
             for r in self.variables
             if isinstance(r, AtomicVariable) and r not in self.inputs
-        )
+        ]
         # C type constants (pytensor.scalar.ScalarType). They don't request an object
         self.consts = []
         # Move c type from orphans (pytensor.scalar.ScalarType) to self.consts

pytensor/link/c/params_type.py
Lines changed: 1 addition & 1 deletion

@@ -810,7 +810,7 @@ def c_support_code(self, **kwargs):
             struct_extract_method=struct_extract_method,
         )

-        return list(sorted(list(c_support_code_set))) + [final_struct_code]
+        return sorted(c_support_code_set) + [final_struct_code]

     def c_code_cache_version(self):
         return ((3,), tuple(t.c_code_cache_version() for t in self.types))
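
Note: this one folds two redundancies at once: sorted() takes the set directly, so the inner list() copy is unneeded, and sorted() already returns a new list, so the outer list() wrap is unneeded (C413 plus C414, by my reading). Check with a dummy set:

    c_support_code_set = {"code_b", "code_a"}   # stand-in support-code snippets
    assert list(sorted(list(c_support_code_set))) == sorted(c_support_code_set)
    assert sorted(c_support_code_set) == ["code_a", "code_b"]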

pytensor/link/jax/dispatch/elemwise.py
Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ def careduce(x):
     elif scalar_op_name:
         scalar_fn_name = scalar_op_name

-    to_reduce = reversed(sorted(axis))
+    to_reduce = sorted(axis, reverse=True)

     if to_reduce:
         # In this case, we need to use the `jax.lax` function (if there
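
Note: beyond style, this looks like a behavioral fix for the `if to_reduce:` guard that follows: reversed() returns an iterator, and iterators are always truthy, so the old code took that branch even for an empty axis, while a sorted list is falsy when empty. (Inferring intent from the snippet alone, but the guard only does useful work with the new spelling.)

    axis = ()                                         # hypothetical empty reduction axis
    assert bool(reversed(sorted(axis))) is True       # iterator: always truthy
    assert bool(sorted(axis, reverse=True)) is False  # empty list: falsy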

pytensor/link/numba/dispatch/elemwise.py
Lines changed: 1 addition & 1 deletion

@@ -361,7 +361,7 @@ def careduce_maximum(input):

     careduce_fn_name = f"careduce_{scalar_op}"
     global_env = {}
-    to_reduce = reversed(sorted(axes))
+    to_reduce = sorted(axes, reverse=True)
     careduce_lines_src = []
     var_name = input_name
