diff --git a/plotly_express/_core.py b/plotly_express/_core.py index 20e05c5..9c7ea9e 100644 --- a/plotly_express/_core.py +++ b/plotly_express/_core.py @@ -1,5 +1,6 @@ import plotly.graph_objs as go from plotly.offline import init_notebook_mode, iplot +import plotly.io as pio from collections import namedtuple, OrderedDict from .colors import qualitative, sequential import math @@ -7,15 +8,15 @@ class PxDefaults(object): def __init__(self): - self.color_discrete_sequence = qualitative.Plotly - self.color_continuous_scale = sequential.Plasma + self.template = None + self.width = None + self.height = 600 + self.color_discrete_sequence = None + self.color_continuous_scale = None self.symbol_sequence = ["circle", "diamond", "square", "x", "cross"] self.line_dash_sequence = ["solid", "dot", "dash", "longdash", "dashdot"] + [ "longdashdot" ] - self.template = "plotly" - self.width = None - self.height = 600 self.size_max = 20 @@ -569,7 +570,7 @@ def make_trace_spec(args, constructor, attrs, trace_patch): if "color" in attrs: if "marker" not in trace_spec.trace_patch: trace_spec.trace_patch["marker"] = dict() - first_default_color = defaults.color_discrete_sequence[0] + first_default_color = args["color_discrete_sequence"][0] trace_spec.trace_patch["marker"]["color"] = first_default_color result.append(trace_spec) if "trendline" in args and args["trendline"]: @@ -588,7 +589,8 @@ def one_group(x): return "" -def infer_config(args, constructor, trace_patch): +def apply_default_cascade(args): + # first we apply px.defaults to unspecified args for param in ( ["color_discrete_sequence", "color_continuous_scale"] + ["symbol_sequence", "line_dash_sequence", "template"] @@ -597,6 +599,43 @@ def infer_config(args, constructor, trace_patch): if param in args and args[param] is None: args[param] = getattr(defaults, param) + # load the default template if set, otherwise "plotly" + if args["template"] is None: + if pio.templates.default is not None: + args["template"] = pio.templates.default + else: + args["template"] = "plotly" + + # retrieve the actual template if we were given a name + try: + template = pio.templates[args["template"]] + except Exception: + template = args["template"] + + # if colors not set explicitly or in px.defaults, defer to a template + # if the template doesn't have one, we set some final fallback defaults + if "color_continuous_scale" in args: + if args["color_continuous_scale"] is None: + try: + args["color_continuous_scale"] = [ + x[1] for x in template.layout.colorscale.sequential + ] + except AttributeError: + pass + if args["color_continuous_scale"] is None: + args["color_continuous_scale"] = sequential.Plasma + + if "color_discrete_sequence" in args: + if args["color_discrete_sequence"] is None: + try: + args["color_discrete_sequence"] = template.layout.colorway + except AttributeError: + pass + if args["color_discrete_sequence"] is None: + args["color_discrete_sequence"] = qualitative.Plotly + + +def infer_config(args, constructor, trace_patch): attrables = ( ["x", "y", "z", "a", "b", "c", "r", "theta", "size"] + ["dimensions", "hover_name", "hover_data", "text", "error_x", "error_x_minus"] @@ -631,9 +670,7 @@ def infer_config(args, constructor, trace_patch): sizeref = 0 if "size" in args and args["size"]: - sizeref = args["data_frame"][args["size"]].max() / ( - args["size_max"] * args["size_max"] - ) + sizeref = args["data_frame"][args["size"]].max() / args["size_max"] ** 2 color_range = None if "color" in args: @@ -702,6 +739,7 @@ def infer_config(args, constructor, trace_patch): def make_figure(args, constructor, trace_patch={}, layout_patch={}): + apply_default_cascade(args) trace_specs, grouped_mappings, sizeref, color_range = infer_config( args, constructor, trace_patch )