diff --git a/larray/viewer.py b/larray/viewer.py index 3b47648fa..f573cfc09 100644 --- a/larray/viewer.py +++ b/larray/viewer.py @@ -1211,8 +1211,8 @@ def plot(self): row_min, row_max, col_min, col_max = self._selection_bounds() dim_names = self.model().xlabels[0] - xlabels = self.model().xlabels - ylabels = self.model().ylabels + xlabels = self.model().xlabels[1][col_min:col_max] + ylabels = self.model().ylabels[1:][row_min:row_max] assert data.ndim == 2 @@ -1224,27 +1224,29 @@ def plot(self): if data.shape[1] == 1: # plot one column xlabel = ','.join(dim_names[:-1]) - xticklabels = ['\n'.join([str(ylabels[j][r]) for j in range(1, len(ylabels))]) - for r in range(row_min, row_max)] - ax.plot(data[:, 0]) - ax.set_ylabel(xlabels[1][col_min]) + xticklabels = ['\n'.join([str(ylabels[c][r]) for c in range(len(ylabels))]) + for r in range(row_max - row_min)] + xdata = np.arange(row_max - row_min, dtype=int) + ax.plot(xdata, data[:, 0]) + ax.set_ylabel(xlabels[0]) else: # plot each row as a line xlabel = dim_names[-1] - xticklabels = [str(xlabels[1][c]) for c in range(col_min, col_max)] + xticklabels = [str(label) for label in xlabels] + xdata = np.arange(col_max - col_min, dtype=int) for row in range(len(data)): - label = ','.join([str(ylabels[j][row_min + row]) - for j in range(1, len(ylabels))]) - ax.plot(data[row], label=label) + label = ','.join([str(label) for label in ylabels[row]]) + ax.plot(xdata, data[row], label=label) # set x axis ax.set_xlabel(xlabel) - ax.set_xlim(0, len(xticklabels) - 1) + ax.set_xlim((xdata[0], xdata[-1])) # we need to do that because matplotlib is smart enough to # not show all ticks but a selection. However, that selection # may include ticks outside the range of x axis xticks = [t for t in ax.get_xticks().astype(int) if t <= len(xticklabels) - 1] - xticklabels = [xticklabels[j] for j in xticks] + xticklabels = [xticklabels[t] for t in xticks] + ax.set_xticks(xticks) ax.set_xticklabels(xticklabels) if data.shape[1] != 1: