diff --git a/src/mplfinance/_arg_validators.py b/src/mplfinance/_arg_validators.py index 351eb1e9..ef91eb05 100644 --- a/src/mplfinance/_arg_validators.py +++ b/src/mplfinance/_arg_validators.py @@ -103,22 +103,45 @@ def _get_valid_plot_types(plottype=None): def _mav_validator(mav_value): - ''' + ''' Value for mav (moving average) keyword may be: - scalar int greater than 1, or tuple of ints, or list of ints (greater than 1). - tuple or list limited to length of 7 moving averages (to keep the plot clean). + scalar int greater than 1, or tuple of ints, or list of ints (each greater than 1) + or a dict of `period` and `shift` each of which may be: + scalar int, or tuple of ints, or list of ints: each `period` int must be greater than 1 ''' - if isinstance(mav_value,int) and mav_value > 1: + def _valid_mav(value, is_period=True): + if not isinstance(value,(tuple,list,int)): + return False + if isinstance(value,int): + return (value >= 2 or not is_period) + # Must be a tuple or list here: + for num in value: + if not isinstance(num,int) or (is_period and num < 2): + return False return True - elif not isinstance(mav_value,tuple) and not isinstance(mav_value,list): + + if not isinstance(mav_value,(tuple,list,int,dict)): return False - if not len(mav_value) < 8: + if not isinstance(mav_value,dict): + return _valid_mav(mav_value) + + else: #isinstance(mav_value,dict) + if 'period' not in mav_value: return False + + period = mav_value['period'] + if not _valid_mav(period): return False + + if 'shift' not in mav_value: return True + + shift = mav_value['shift'] + if not _valid_mav(shift, False): return False + if isinstance(period,int) and isinstance(shift,int): return True + if isinstance(period,(tuple,list)) and isinstance(shift,(tuple,list)): + if len(period) != len(shift): return False + return True return False - for num in mav_value: - if not isinstance(num,int) and num > 1: - return False - return True + def _hlines_validator(value): if isinstance(value,dict): diff --git a/src/mplfinance/_version.py b/src/mplfinance/_version.py index f438cd1d..ac383efb 100644 --- a/src/mplfinance/_version.py +++ b/src/mplfinance/_version.py @@ -1,5 +1,5 @@ -version_info = (0, 12, 7, 'alpha', 17) +version_info = (0, 12, 7, 'alpha', 18) _specifier_ = {'alpha': 'a','beta': 'b','candidate': 'rc','final': ''} diff --git a/src/mplfinance/plotting.py b/src/mplfinance/plotting.py index e2219903..c6cccc0b 100644 --- a/src/mplfinance/plotting.py +++ b/src/mplfinance/plotting.py @@ -978,8 +978,12 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None): mavgs = config['mav'] mavp_list = [] if mavgs is not None: + shift = None + if isinstance(mavgs,dict): + shift = mavgs['shift'] + mavgs = mavgs['period'] if isinstance(mavgs,int): - mavgs = mavgs, # convert to tuple + mavgs = mavgs, # convert to tuple if len(mavgs) > 7: mavgs = mavgs[0:7] # take at most 7 @@ -988,8 +992,11 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None): else: mavc = None - for mav in mavgs: - mavprices = pd.Series(prices).rolling(mav).mean().values + for idx,mav in enumerate(mavgs): + mean = pd.Series(prices).rolling(mav).mean() + if shift is not None: + mean = mean.shift(periods=shift[idx]) + mavprices = mean.values lw = config['_width_config']['line_width'] if mavc: ax.plot(xdates, mavprices, linewidth=lw, color=next(mavc)) diff --git a/tests/reference_images/addplot12.png b/tests/reference_images/addplot12.png new file mode 100644 index 00000000..684d8fb0 Binary files /dev/null and b/tests/reference_images/addplot12.png differ diff --git a/tests/test_addplot.py b/tests/test_addplot.py index 5c327f63..5a4867d7 100644 --- a/tests/test_addplot.py +++ b/tests/test_addplot.py @@ -354,3 +354,23 @@ def test_addplot11(bolldata): print('result=',result) assert result is None +def test_addplot12(bolldata): + + df = bolldata + + fname = base+'12.png' + tname = os.path.join(tdir,fname) + rname = os.path.join(refd,fname) + + mpf.plot(df,type='candle',volume=True,savefig=tname,mav={'period':(20,40,60), 'shift': [5,10,20]}) + + tsize = os.path.getsize(tname) + print(glob.glob(tname),'[',tsize,'bytes',']') + + rsize = os.path.getsize(rname) + print(glob.glob(rname),'[',rsize,'bytes',']') + + result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE) + if result is not None: + print('result=',result) + assert result is None