diff --git a/doc/source/whatsnew/v0.15.2.txt b/doc/source/whatsnew/v0.15.2.txt index 929471acb3105..aca1320980720 100644 --- a/doc/source/whatsnew/v0.15.2.txt +++ b/doc/source/whatsnew/v0.15.2.txt @@ -166,3 +166,10 @@ Bug Fixes not lexically sorted or unique (:issue:`7724`) - BUG CSV: fix problem with trailing whitespace in skipped rows, (:issue:`8679`), (:issue:`8661`) - Regression in ``Timestamp`` does not parse 'Z' zone designator for UTC (:issue:`8771`) + + + + + +- Bug in `StataWriter` the produces writes strings with 244 characters irrespective of actual size (:issue:`8969`) + diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 45d3274088c75..cd37efd8e0991 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -1409,7 +1409,7 @@ def _maybe_convert_to_int_keys(convert_dates, varlist): return new_dict -def _dtype_to_stata_type(dtype): +def _dtype_to_stata_type(dtype, column): """ Converts dtype types to stata types. Returns the byte of the given ordinal. See TYPE_MAP and comments for an explanation. This is also explained in @@ -1425,13 +1425,14 @@ def _dtype_to_stata_type(dtype): If there are dates to convert, then dtype will already have the correct type inserted. """ - #TODO: expand to handle datetime to integer conversion + # TODO: expand to handle datetime to integer conversion if dtype.type == np.string_: return chr(dtype.itemsize) elif dtype.type == np.object_: # try to coerce it to the biggest string # not memory efficient, what else could we # do? - return chr(244) + itemsize = max_len_string_array(column.values) + return chr(max(itemsize, 1)) elif dtype == np.float64: return chr(255) elif dtype == np.float32: @@ -1461,6 +1462,7 @@ def _dtype_to_default_stata_fmt(dtype, column): int16 -> "%8.0g" int8 -> "%8.0g" """ + # TODO: Refactor to combine type with format # TODO: expand this to handle a default datetime format? if dtype.type == np.object_: inferred_dtype = infer_dtype(column.dropna()) @@ -1470,8 +1472,7 @@ def _dtype_to_default_stata_fmt(dtype, column): itemsize = max_len_string_array(column.values) if itemsize > 244: raise ValueError(excessive_string_length_error % column.name) - - return "%" + str(itemsize) + "s" + return "%" + str(max(itemsize, 1)) + "s" elif dtype == np.float64: return "%10.0g" elif dtype == np.float32: @@ -1718,10 +1719,11 @@ def _prepare_pandas(self, data): self._convert_dates[key] ) dtypes[key] = np.dtype(new_type) - self.typlist = [_dtype_to_stata_type(dt) for dt in dtypes] + self.typlist = [] self.fmtlist = [] for col, dtype in dtypes.iteritems(): self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, data[col])) + self.typlist.append(_dtype_to_stata_type(dtype, data[col])) # set the given format for the datetime cols if self._convert_dates is not None: diff --git a/pandas/io/tests/test_stata.py b/pandas/io/tests/test_stata.py index a99bcf741792f..6a3c16655745e 100644 --- a/pandas/io/tests/test_stata.py +++ b/pandas/io/tests/test_stata.py @@ -593,10 +593,12 @@ def test_minimal_size_col(self): with tm.ensure_clean() as path: original.to_stata(path, write_index=False) sr = StataReader(path) + typlist = sr.typlist variables = sr.varlist formats = sr.fmtlist - for variable, fmt in zip(variables, formats): + for variable, fmt, typ in zip(variables, formats, typlist): self.assertTrue(int(variable[1:]) == int(fmt[1:-1])) + self.assertTrue(int(variable[1:]) == typ) def test_excessively_long_string(self): str_lens = (1, 244, 500) @@ -850,7 +852,6 @@ def test_categorical_order(self): # Check identity of codes for col in expected: if is_categorical_dtype(expected[col]): - print(col) tm.assert_series_equal(expected[col].cat.codes, parsed_115[col].cat.codes) tm.assert_index_equal(expected[col].cat.categories, diff --git a/pandas/lib.pyx b/pandas/lib.pyx index 82408cd460fcd..2a5b93d111acc 100644 --- a/pandas/lib.pyx +++ b/pandas/lib.pyx @@ -898,17 +898,17 @@ def clean_index_list(list obj): @cython.boundscheck(False) @cython.wraparound(False) -def max_len_string_array(ndarray[object, ndim=1] arr): +def max_len_string_array(ndarray arr): """ return the maximum size of elements in a 1-dim string array """ cdef: int i, m, l - length = arr.shape[0] + int length = arr.shape[0] object v m = 0 for i from 0 <= i < length: v = arr[i] - if PyString_Check(v) or PyBytes_Check(v): + if PyString_Check(v) or PyBytes_Check(v) or PyUnicode_Check(v): l = len(v) if l > m: