diff --git a/doc/source/whatsnew/v1.0.0.rst b/doc/source/whatsnew/v1.0.0.rst index e8663853b7684..aa342fc58b38a 100755 --- a/doc/source/whatsnew/v1.0.0.rst +++ b/doc/source/whatsnew/v1.0.0.rst @@ -221,8 +221,8 @@ Other enhancements - DataFrame constructor preserve `ExtensionArray` dtype with `ExtensionArray` (:issue:`11363`) - :meth:`DataFrame.sort_values` and :meth:`Series.sort_values` have gained ``ignore_index`` keyword to be able to reset index after sorting (:issue:`30114`) - :meth:`DataFrame.to_markdown` and :meth:`Series.to_markdown` added (:issue:`11052`) - - :meth:`DataFrame.drop_duplicates` has gained ``ignore_index`` keyword to reset index (:issue:`30114`) +- Added new writer for exporting Stata dta files in version 118, ``StataWriter118``. This format supports exporting strings containing Unicode characters (:issue:`23573`) Build Changes ^^^^^^^^^^^^^ diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 1de0d3b58dc5f..e18b7f50e7723 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -1929,14 +1929,17 @@ def to_stata( >>> df.to_stata('animals.dta') # doctest: +SKIP """ kwargs = {} - if version not in (114, 117): - raise ValueError("Only formats 114 and 117 supported.") + if version not in (114, 117, 118): + raise ValueError("Only formats 114, 117 and 118 are supported.") if version == 114: if convert_strl is not None: - raise ValueError("strl support is only available when using format 117") + raise ValueError("strl is not supported in format 114") from pandas.io.stata import StataWriter as statawriter else: - from pandas.io.stata import StataWriter117 as statawriter + if version == 117: + from pandas.io.stata import StataWriter117 as statawriter + else: + from pandas.io.stata import StataWriter118 as statawriter kwargs["convert_strl"] = convert_strl diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 1f8c6968359c1..b216ee80c3940 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -85,7 +85,7 @@ iterator : bool, 
default False Return StataReader object.""" -_read_stata_doc = """ +_read_stata_doc = f""" Read Stata file into DataFrame. Parameters @@ -100,10 +100,10 @@ By file-like object, we refer to objects with a ``read()`` method, such as a file handler (e.g. via builtin ``open`` function) or ``StringIO``. -%s -%s -%s -%s +{_statafile_processing_params1} +{_statafile_processing_params2} +{_chunksize_params} +{_iterator_params} Returns ------- @@ -125,33 +125,24 @@ >>> itr = pd.read_stata('filename.dta', chunksize=10000) >>> for chunk in itr: ... do_something(chunk) -""" % ( - _statafile_processing_params1, - _statafile_processing_params2, - _chunksize_params, - _iterator_params, -) +""" -_read_method_doc = """\ +_read_method_doc = f"""\ Reads observations from Stata file, converting them into a dataframe Parameters ---------- nrows : int Number of lines to read from data file, if None read whole file. -%s -%s +{_statafile_processing_params1} +{_statafile_processing_params2} Returns ------- DataFrame -""" % ( - _statafile_processing_params1, - _statafile_processing_params2, -) - +""" -_stata_reader_doc = """\ +_stata_reader_doc = f"""\ Class for reading Stata dta files. Parameters @@ -161,14 +152,10 @@ implementing a binary read() functions. .. versionadded:: 0.23.0 support for pathlib, py.path. 
-%s -%s -%s -""" % ( - _statafile_processing_params1, - _statafile_processing_params2, - _chunksize_params, -) +{_statafile_processing_params1} +{_statafile_processing_params2} +{_chunksize_params} +""" @Appender(_read_stata_doc) @@ -370,7 +357,7 @@ def convert_delta_safe(base, deltas, unit): month = np.ones_like(dates) conv_dates = convert_year_month_safe(year, month) else: - raise ValueError("Date fmt {fmt} not understood".format(fmt=fmt)) + raise ValueError(f"Date fmt {fmt} not understood") if has_bad_values: # Restore NaT for bad values conv_dates[bad_locs] = NaT @@ -465,9 +452,7 @@ def parse_dates_safe(dates, delta=False, year=False, days=False): d = parse_dates_safe(dates, year=True) conv_dates = d.year else: - raise ValueError( - "Format {fmt} is not a known Stata date format".format(fmt=fmt) - ) + raise ValueError(f"Format {fmt} is not a known Stata date format") conv_dates = Series(conv_dates, dtype=np.float64) missing_value = struct.unpack("= 2 ** 53: - ws = precision_loss_doc % ("uint64", "float64") + ws = precision_loss_doc.format("uint64", "float64") data[col] = data[col].astype(dtype) @@ -585,25 +570,21 @@ def _cast_to_stata_types(data): else: data[col] = data[col].astype(np.float64) if data[col].max() >= 2 ** 53 or data[col].min() <= -(2 ** 53): - ws = precision_loss_doc % ("int64", "float64") + ws = precision_loss_doc.format("int64", "float64") elif dtype in (np.float32, np.float64): value = data[col].max() if np.isinf(value): raise ValueError( - "Column {col} has a maximum value of " - "infinity which is outside the range " - "supported by Stata.".format(col=col) + f"Column {col} has a maximum value of infinity which is outside " + "the range supported by Stata." 
) if dtype == np.float32 and value > float32_max: data[col] = data[col].astype(np.float64) elif dtype == np.float64: if value > float64_max: raise ValueError( - "Column {col} has a maximum value " - "({val}) outside the range supported by " - "Stata ({float64_max})".format( - col=col, val=value, float64_max=float64_max - ) + f"Column {col} has a maximum value ({value}) outside the range " + f"supported by Stata ({float64_max})" ) if ws: @@ -618,26 +599,18 @@ class StataValueLabel: Parameters ---------- - value : int8, int16, int32, float32 or float64 - The Stata missing value code - - Attributes - ---------- - string : string - String representation of the Stata missing value - value : int8, int16, int32, float32 or float64 - The original encoded missing value - - Methods - ------- - generate_value_label - + catarray : Categorical + Categorical Series to encode + encoding : {"latin-1", "utf-8"} + Encoding to use for value labels. """ - def __init__(self, catarray): + def __init__(self, catarray, encoding="latin-1"): + if encoding not in ("latin-1", "utf-8"): + raise ValueError("Only latin-1 and utf-8 are supported.") self.labname = catarray.name - + self._encoding = encoding categories = catarray.cat.categories self.value_labels = list(zip(np.arange(len(categories)), categories)) self.value_labels.sort(key=lambda x: x[0]) @@ -656,7 +629,7 @@ def __init__(self, catarray): value_label_mismatch_doc.format(catarray.name), ValueLabelTypeMismatch, ) - + category = category.encode(encoding) self.off.append(self.text_len) self.text_len += len(category) + 1 # +1 for the padding self.val.append(vl[0]) @@ -683,31 +656,31 @@ def _encode(self, s): """ return s.encode(self._encoding) - def generate_value_label(self, byteorder, encoding): + def generate_value_label(self, byteorder): """ + Generate the binary representation of the value labels. 
+ Parameters ---------- byteorder : str Byte order of the output - encoding : str - File encoding Returns ------- value_label : bytes Bytes containing the formatted value label """ - - self._encoding = encoding + encoding = self._encoding bio = BytesIO() - null_string = "\x00" null_byte = b"\x00" # len bio.write(struct.pack(byteorder + "i", self.len)) # labname - labname = self._encode(_pad_bytes(self.labname[:32], 33)) + labname = self.labname[:32].encode(encoding) + lab_len = 32 if encoding not in ("utf-8", "utf8") else 128 + labname = _pad_bytes(labname, lab_len + 1) bio.write(labname) # padding - 3 bytes @@ -731,7 +704,7 @@ def generate_value_label(self, byteorder, encoding): # txt - Text labels, null terminated for text in self.txt: - bio.write(self._encode(text + null_string)) + bio.write(text + null_byte) bio.seek(0) return bio.read() @@ -1007,6 +980,22 @@ def __init__(self): "typedef", "typename", "virtual", + "_all", + "_N", + "_skip", + "_b", + "_pi", + "str#", + "in", + "_pred", + "strL", + "_coef", + "_rc", + "using", + "_cons", + "_se", + "with", + "_n", ) @@ -1192,7 +1181,7 @@ def f(typ): try: return self.TYPE_MAP_XML[typ] except KeyError: - raise ValueError("cannot convert stata types [{0}]".format(typ)) + raise ValueError(f"cannot convert stata types [{typ}]") typlist = [f(x) for x in raw_typlist] @@ -1202,7 +1191,7 @@ def f(typ): try: return self.DTYPE_MAP_XML[typ] except KeyError: - raise ValueError("cannot convert stata dtype [{0}]".format(typ)) + raise ValueError(f"cannot convert stata dtype [{typ}]") dtyplist = [f(x) for x in raw_typlist] @@ -1330,19 +1319,13 @@ def _read_old_header(self, first_char): try: self.typlist = [self.TYPE_MAP[typ] for typ in typlist] except ValueError: - raise ValueError( - "cannot convert stata types [{0}]".format( - ",".join(str(x) for x in typlist) - ) - ) + invalid_types = ",".join(str(x) for x in typlist) + raise ValueError(f"cannot convert stata types [{invalid_types}]") try: self.dtyplist = [self.DTYPE_MAP[typ] 
for typ in typlist] except ValueError: - raise ValueError( - "cannot convert stata dtypes [{0}]".format( - ",".join(str(x) for x in typlist) - ) - ) + invalid_dtypes = ",".join(str(x) for x in typlist) + raise ValueError(f"cannot convert stata dtypes [{invalid_dtypes}]") if self.format_version > 108: self.varlist = [ @@ -1415,12 +1398,13 @@ def _decode(self, s): except UnicodeDecodeError: # GH 25960, fallback to handle incorrect format produced when 117 # files are converted to 118 files in Stata - msg = """ + encoding = self._encoding + msg = f""" One or more strings in the dta file could not be decoded using {encoding}, and so the fallback encoding of latin-1 is being used. This can happen when a file has been incorrectly encoded by Stata or some other software. You should verify the string values returned are correct.""" - warnings.warn(msg.format(encoding=self._encoding), UnicodeWarning) + warnings.warn(msg, UnicodeWarning) return s.decode("latin-1") def _read_value_labels(self): @@ -1794,7 +1778,7 @@ def _do_convert_categoricals( repeats = list(vc.index[vc > 1]) repeats = "-" * 80 + "\n" + "\n".join(repeats) # GH 25772 - msg = """ + msg = f""" Value labels for column {col} are not unique. These cannot be converted to pandas categoricals. @@ -1805,7 +1789,7 @@ def _do_convert_categoricals( The repeated labels are: {repeats} """ - raise ValueError(msg.format(col=col, repeats=repeats)) + raise ValueError(msg) # TODO: is the next line needed above in the data(...) method? cat_data = Series(cat_data, index=data.index) cat_converted_data.append((col, cat_data)) @@ -1874,13 +1858,15 @@ def _set_endianness(endianness): elif endianness.lower() in [">", "big"]: return ">" else: # pragma : no cover - raise ValueError("Endianness {endian} not understood".format(endian=endianness)) + raise ValueError(f"Endianness {endianness} not understood") def _pad_bytes(name, length): """ Take a char string and pads it with null bytes until it's length chars. 
""" + if isinstance(name, bytes): + return name + b"\x00" * (length - len(name)) return name + "\x00" * (length - len(name)) @@ -1906,7 +1892,7 @@ def _convert_datetime_to_stata_type(fmt): ]: return np.float64 # Stata expects doubles for SIFs else: - raise NotImplementedError("Format {fmt} not implemented".format(fmt=fmt)) + raise NotImplementedError(f"Format {fmt} not implemented") def _maybe_convert_to_int_keys(convert_dates, varlist): @@ -1956,9 +1942,7 @@ def _dtype_to_stata_type(dtype, column): elif dtype == np.int8: return 251 else: # pragma : no cover - raise NotImplementedError( - "Data type {dtype} not supported.".format(dtype=dtype) - ) + raise NotImplementedError(f"Data type {dtype} not supported.") def _dtype_to_default_stata_fmt(dtype, column, dta_version=114, force_strl=False): @@ -1985,24 +1969,12 @@ def _dtype_to_default_stata_fmt(dtype, column, dta_version=114, force_strl=False if force_strl: return "%9s" if dtype.type == np.object_: - inferred_dtype = infer_dtype(column, skipna=True) - if not (inferred_dtype in ("string", "unicode") or len(column) == 0): - raise ValueError( - "Column `{col}` cannot be exported.\n\nOnly " - "string-like object arrays containing all " - "strings or a mix of strings and None can be " - "exported. Object arrays containing only null " - "values are prohibited. 
Other object types" - "cannot be exported and must first be converted " - "to one of the supported " - "types.".format(col=column.name) - ) itemsize = max_len_string_array(ensure_object(column.values)) if itemsize > max_str_len: if dta_version >= 117: return "%9s" else: - raise ValueError(excessive_string_length_error % column.name) + raise ValueError(excessive_string_length_error.format(column.name)) return "%" + str(max(itemsize, 1)) + "s" elif dtype == np.float64: return "%10.0g" @@ -2013,9 +1985,7 @@ def _dtype_to_default_stata_fmt(dtype, column, dta_version=114, force_strl=False elif dtype == np.int8 or dtype == np.int16: return "%8.0g" else: # pragma : no cover - raise NotImplementedError( - "Data type {dtype} not supported.".format(dtype=dtype) - ) + raise NotImplementedError(f"Data type {dtype} not supported.") class StataWriter(StataParser): @@ -2043,8 +2013,6 @@ class StataWriter(StataParser): timezone information write_index : bool Write the index to Stata dataset. - encoding : str - Default is latin-1. Only latin-1 and ascii are supported. byteorder : str Can be ">", "<", "little", or "big". 
default is `sys.byteorder` time_stamp : datetime @@ -2086,6 +2054,7 @@ class StataWriter(StataParser): """ _max_string_length = 244 + _encoding = "latin-1" def __init__( self, @@ -2101,7 +2070,6 @@ def __init__( super().__init__() self._convert_dates = {} if convert_dates is None else convert_dates self._write_index = write_index - self._encoding = "latin-1" self._time_stamp = time_stamp self._data_label = data_label self._variable_labels = variable_labels @@ -2136,7 +2104,8 @@ def _prepare_categoricals(self, data): data_formatted = [] for col, col_is_cat in zip(data, is_cat): if col_is_cat: - self._value_labels.append(StataValueLabel(data[col])) + svl = StataValueLabel(data[col], encoding=self._encoding) + self._value_labels.append(svl) dtype = data[col].cat.codes.dtype if dtype == np.int64: raise ValueError( @@ -2181,6 +2150,36 @@ def _update_strl_names(self): """No-op, forward compatibility""" pass + def _validate_variable_name(self, name): + """ + Validate variable names for Stata export. + + Parameters + ---------- + name : str + Variable name + + Returns + ------- + str + The validated name with invalid characters replaced with + underscores. + + Notes + ----- + Stata 114 and 117 support ascii characters in a-z, A-Z, 0-9 + and _. + """ + for c in name: + if ( + (c < "A" or c > "Z") + and (c < "a" or c > "z") + and (c < "0" or c > "9") + and c != "_" + ): + name = name.replace(c, "_") + return name + def _check_column_names(self, data): """ Checks column names to ensure that they are valid Stata column names. 
@@ -2204,14 +2203,7 @@ def _check_column_names(self, data): if not isinstance(name, str): name = str(name) - for c in name: - if ( - (c < "A" or c > "Z") - and (c < "a" or c > "z") - and (c < "0" or c > "9") - and c != "_" - ): - name = name.replace(c, "_") + name = self._validate_variable_name(name) # Variable name must not be a reserved word if name in self.RESERVED_WORDS: @@ -2251,7 +2243,7 @@ def _check_column_names(self, data): orig_name = orig_name.encode("utf-8") except (UnicodeDecodeError, AttributeError): pass - msg = "{0} -> {1}".format(orig_name, name) + msg = f"{orig_name} -> {name}" conversion_warning.append(msg) ws = invalid_name_doc.format("\n ".join(conversion_warning)) @@ -2262,12 +2254,12 @@ def _check_column_names(self, data): return data - def _set_formats_and_types(self, data, dtypes): + def _set_formats_and_types(self, dtypes): self.typlist = [] self.fmtlist = [] for col, dtype in dtypes.items(): - self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, data[col])) - self.typlist.append(_dtype_to_stata_type(dtype, data[col])) + self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, self.data[col])) + self.typlist.append(_dtype_to_stata_type(dtype, self.data[col])) def _prepare_pandas(self, data): # NOTE: we might need a different API / class for pandas objects so @@ -2311,17 +2303,57 @@ def _prepare_pandas(self, data): new_type = _convert_datetime_to_stata_type(self._convert_dates[key]) dtypes[key] = np.dtype(new_type) - self._set_formats_and_types(data, dtypes) + # Verify object arrays are strings and encode to bytes + self._encode_strings() + + self._set_formats_and_types(dtypes) # set the given format for the datetime cols if self._convert_dates is not None: for key in self._convert_dates: self.fmtlist[key] = self._convert_dates[key] + def _encode_strings(self): + """ + Encode strings in dta-specific encoding + + Do not encode columns marked for date conversion or for strL + conversion. 
The strL converter independently handles conversion and + also accepts empty string arrays. + """ + convert_dates = self._convert_dates + # _convert_strl is not available in dta 114 + convert_strl = getattr(self, "_convert_strl", []) + for i, col in enumerate(self.data): + # Skip columns marked for date conversion or strl conversion + if i in convert_dates or col in convert_strl: + continue + column = self.data[col] + dtype = column.dtype + if dtype.type == np.object_: + inferred_dtype = infer_dtype(column, skipna=True) + if not ((inferred_dtype in ("string", "unicode")) or len(column) == 0): + col = column.name + raise ValueError( + f"""\ +Column `{col}` cannot be exported.\n\nOnly string-like object arrays +containing all strings or a mix of strings and None can be exported. +Object arrays containing only null values are prohibited. Other object +types cannot be exported and must first be converted to one of the +supported types.""" + ) + encoded = self.data[col].str.encode(self._encoding) + # If larger than _max_string_length do nothing + if ( + max_len_string_array(ensure_object(encoded.values)) + <= self._max_string_length + ): + self.data[col] = encoded + def write_file(self): self._file, self._own_file = _open_file_binary_write(self._fname) try: - self._write_header(time_stamp=self._time_stamp, data_label=self._data_label) + self._write_header(data_label=self._data_label, time_stamp=self._time_stamp) self._write_map() self._write_variable_types() self._write_varnames() @@ -2344,9 +2376,8 @@ def write_file(self): os.unlink(self._fname) except OSError: warnings.warn( - "This save was not successful but {0} could not " - "be deleted. This file is not " - "valid.".format(self._fname), + f"This save was not successful but {self._fname} could not " + "be deleted. 
This file is not valid.", ResourceWarning, ) raise exc @@ -2392,7 +2423,7 @@ def _write_expansion_fields(self): def _write_value_labels(self): for vl in self._value_labels: - self._file.write(vl.generate_value_label(self._byteorder, self._encoding)) + self._file.write(vl.generate_value_label(self._byteorder)) def _write_header(self, data_label=None, time_stamp=None): byteorder = self._byteorder @@ -2494,9 +2525,8 @@ def _write_variable_labels(self): is_latin1 = all(ord(c) < 256 for c in label) if not is_latin1: raise ValueError( - "Variable labels must contain only " - "characters that can be encoded in " - "Latin-1" + "Variable labels must contain only characters that " + "can be encoded in Latin-1" ) self._write(_pad_bytes(label, 81)) else: @@ -2527,9 +2557,9 @@ def _prepare_data(self): typ = typlist[i] if typ <= self._max_string_length: data[col] = data[col].fillna("").apply(_pad_bytes, args=(typ,)) - stype = "S{type}".format(type=typ) + stype = f"S{typ}" dtypes[col] = stype - data[col] = data[col].str.encode(self._encoding).astype(stype) + data[col] = data[col].astype(stype) else: dtype = data[col].dtype if not native_byteorder: @@ -2715,12 +2745,6 @@ def generate_table(self): return gso_table, gso_df - def _encode(self, s): - """ - Python 3 compatibility shim - """ - return s.encode(self._encoding) - def generate_blob(self, gso_table): """ Generates the binary blob of GSOs that is written to the dta file. 
@@ -2860,6 +2884,7 @@ class StataWriter117(StataWriter): """ _max_string_length = 2045 + _dta_version = 117 def __init__( self, @@ -2906,18 +2931,21 @@ def _write_header(self, data_label=None, time_stamp=None): self._file.write(bytes("", "utf-8")) bio = BytesIO() # ds_format - 117 - bio.write(self._tag(bytes("117", "utf-8"), "release")) + bio.write(self._tag(bytes(str(self._dta_version), "utf-8"), "release")) # byteorder bio.write(self._tag(byteorder == ">" and "MSF" or "LSF", "byteorder")) # number of vars, 2 bytes assert self.nvar < 2 ** 16 bio.write(self._tag(struct.pack(byteorder + "H", self.nvar), "K")) - # number of obs, 4 bytes - bio.write(self._tag(struct.pack(byteorder + "I", self.nobs), "N")) + # 117 uses 4 bytes, 118 uses 8 + nobs_size = "I" if self._dta_version == 117 else "Q" + bio.write(self._tag(struct.pack(byteorder + nobs_size, self.nobs), "N")) # data label 81 bytes, char, null terminated label = data_label[:80] if data_label is not None else "" - label_len = struct.pack(byteorder + "B", len(label)) - label = label_len + bytes(label, "utf-8") + label = label.encode(self._encoding) + label_size = "B" if self._dta_version == 117 else "H" + label_len = struct.pack(byteorder + label_size, len(label)) + label = label_len + label bio.write(self._tag(label, "label")) # time stamp, 18 bytes, char, null terminated # format dd Mon yyyy hh:mm @@ -2947,7 +2975,7 @@ def _write_header(self, data_label=None, time_stamp=None): + time_stamp.strftime(" %Y %H:%M") ) # '\x11' added due to inspection of Stata file - ts = b"\x11" + bytes(ts, "utf8") + ts = b"\x11" + bytes(ts, "utf-8") bio.write(self._tag(ts, "timestamp")) bio.seek(0) self._file.write(self._tag(bio.read(), "header")) @@ -2994,9 +3022,11 @@ def _write_variable_types(self): def _write_varnames(self): self._update_map("varnames") bio = BytesIO() + # 118 scales by 4 to accommodate utf-8 data worst case encoding + vn_len = 32 if self._dta_version == 117 else 128 for name in self.varlist: name = 
self._null_terminate(name, True) - name = _pad_bytes_new(name[:32], 33) + name = _pad_bytes_new(name[:32].encode(self._encoding), vn_len + 1) bio.write(name) bio.seek(0) self._file.write(self._tag(bio.read(), "varnames")) @@ -3008,21 +3038,24 @@ def _write_sortlist(self): def _write_formats(self): self._update_map("formats") bio = BytesIO() + fmt_len = 49 if self._dta_version == 117 else 57 for fmt in self.fmtlist: - bio.write(_pad_bytes_new(fmt, 49)) + bio.write(_pad_bytes_new(fmt.encode(self._encoding), fmt_len)) bio.seek(0) self._file.write(self._tag(bio.read(), "formats")) def _write_value_label_names(self): self._update_map("value_label_names") bio = BytesIO() + # 118 scales by 4 to accommodate utf-8 data worst case encoding + vl_len = 32 if self._dta_version == 117 else 128 for i in range(self.nvar): # Use variable name when categorical name = "" # default name if self._is_col_cat[i]: name = self.varlist[i] name = self._null_terminate(name, True) - name = _pad_bytes_new(name[:32], 33) + name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1) bio.write(name) bio.seek(0) self._file.write(self._tag(bio.read(), "value_label_names")) @@ -3031,7 +3064,9 @@ def _write_variable_labels(self): # Missing labels are 80 blank characters plus null termination self._update_map("variable_labels") bio = BytesIO() - blank = _pad_bytes_new("", 81) + # 118 scales by 4 to accommodate utf-8 data worst case encoding + vl_len = 80 if self._dta_version == 117 else 320 + blank = _pad_bytes_new("", vl_len + 1) if self._variable_labels is None: for _ in range(self.nvar): @@ -3045,14 +3080,15 @@ def _write_variable_labels(self): label = self._variable_labels[col] if len(label) > 80: raise ValueError("Variable labels must be 80 characters or fewer") - is_latin1 = all(ord(c) < 256 for c in label) - if not is_latin1: + try: + encoded = label.encode(self._encoding) + except UnicodeEncodeError: raise ValueError( - "Variable labels must contain only " - "characters that can be 
encoded in " - "Latin-1" + "Variable labels must contain only characters that " + f"can be encoded in {self._encoding}" ) - bio.write(_pad_bytes_new(label, 81)) + + bio.write(_pad_bytes_new(encoded, vl_len + 1)) else: bio.write(blank) bio.seek(0) @@ -3084,7 +3120,7 @@ def _write_value_labels(self): self._update_map("value_labels") bio = BytesIO() for vl in self._value_labels: - lab = vl.generate_value_label(self._byteorder, self._encoding) + lab = vl.generate_value_label(self._byteorder) lab = self._tag(lab, "lbl") bio.write(lab) bio.seek(0) @@ -3114,19 +3150,140 @@ def _convert_strls(self, data): ] if convert_cols: - ssw = StataStrLWriter(data, convert_cols) + ssw = StataStrLWriter(data, convert_cols, version=self._dta_version) tab, new_data = ssw.generate_table() data = new_data self._strl_blob = ssw.generate_blob(tab) return data - def _set_formats_and_types(self, data, dtypes): + def _set_formats_and_types(self, dtypes): self.typlist = [] self.fmtlist = [] for col, dtype in dtypes.items(): force_strl = col in self._convert_strl fmt = _dtype_to_default_stata_fmt( - dtype, data[col], dta_version=117, force_strl=force_strl + dtype, + self.data[col], + dta_version=self._dta_version, + force_strl=force_strl, ) self.fmtlist.append(fmt) - self.typlist.append(_dtype_to_stata_type_117(dtype, data[col], force_strl)) + self.typlist.append( + _dtype_to_stata_type_117(dtype, self.data[col], force_strl) + ) + + +class StataWriter118(StataWriter117): + """ + A class for writing Stata binary dta files in Stata 15 format (118) + + DTA 118 format files support unicode string data (both fixed and strL) + format. Unicode is also supported in value labels, variable labels and + the dataset label. + + .. versionadded:: 1.0.0 + + Parameters + ---------- + fname : path (string), buffer or path object + string, path object (pathlib.Path or py._path.local.LocalPath) or + object implementing a binary write() functions. 
If using a buffer + then the buffer will not be automatically closed after the file + is written. + data : DataFrame + Input to save + convert_dates : dict + Dictionary mapping columns containing datetime types to stata internal + format to use when writing the dates. Options are 'tc', 'td', 'tm', + 'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name. + Datetime columns that do not have a conversion type specified will be + converted to 'tc'. Raises NotImplementedError if a datetime column has + timezone information + write_index : bool + Write the index to Stata dataset. + byteorder : str + Can be ">", "<", "little", or "big". default is `sys.byteorder` + time_stamp : datetime + A datetime to use as file creation date. Default is the current time + data_label : str + A label for the data set. Must be 80 characters or smaller. + variable_labels : dict + Dictionary containing columns as keys and variable labels as values. + Each label must be 80 characters or smaller. + convert_strl : list + List of columns names to convert to Stata StrL format. Columns with + more than 2045 characters are automatically written as StrL. + Smaller columns can be converted by including the column name. Using + StrLs can reduce output file size when strings are longer than 8 + characters, and either frequently repeated or sparse. + + Returns + ------- + StataWriter118 + The instance has a write_file method, which will write the file to the + given `fname`. 
+ + Raises + ------ + NotImplementedError + * If datetimes contain timezone information + ValueError + * Columns listed in convert_dates are neither datetime64[ns] + or datetime.datetime + * Column dtype is not representable in Stata + * Column listed in convert_dates is not in DataFrame + * Categorical label contains more than 32,000 characters + + Examples + -------- + Using Unicode data and column names + + >>> from pandas.io.stata import StataWriter118 + >>> data = pd.DataFrame([[1.0, 1, 'ᴬ']], columns=['a', 'β', 'ĉ']) + >>> writer = StataWriter118('./data_file.dta', data) + >>> writer.write_file() + + Or with long strings stored in strl format + + >>> data = pd.DataFrame([['ᴀ relatively long ŝtring'], [''], ['']], + ... columns=['strls']) + >>> writer = StataWriter118('./data_file_with_long_strings.dta', data, + ... convert_strl=['strls']) + >>> writer.write_file() + """ + + _encoding = "utf-8" + _dta_version = 118 + + def _validate_variable_name(self, name): + """ + Validate variable names for Stata export. + + Parameters + ---------- + name : str + Variable name + + Returns + ------- + str + The validated name with invalid characters replaced with + underscores. + + Notes + ----- + Stata 118 supports most unicode characters. The only limitation is in + the ascii range where the characters supported are a-z, A-Z, 0-9 and _. 
+ """ + # High code points appear to be acceptable + for c in name: + if ( + ord(c) < 128 + and (c < "A" or c > "Z") + and (c < "a" or c > "z") + and (c < "0" or c > "9") + and c != "_" + ) or 128 <= ord(c) < 256: + name = name.replace(c, "_") + + return name diff --git a/pandas/tests/io/test_stata.py b/pandas/tests/io/test_stata.py index cbc5ebd986c15..e8bc7f480fb1d 100644 --- a/pandas/tests/io/test_stata.py +++ b/pandas/tests/io/test_stata.py @@ -21,6 +21,7 @@ PossiblePrecisionLoss, StataMissingValue, StataReader, + StataWriter118, read_stata, ) @@ -1271,11 +1272,9 @@ def test_invalid_variable_labels(self, version): variable_labels["a"] = "invalid character Œ" with tm.ensure_clean() as path: - msg = ( - "Variable labels must contain only characters that can be" - " encoded in Latin-1" - ) - with pytest.raises(ValueError, match=msg): + with pytest.raises( + ValueError, match="Variable labels must contain only characters" + ): original.to_stata( path, variable_labels=variable_labels, version=version ) @@ -1425,8 +1424,8 @@ def test_out_of_range_double(self): } ) msg = ( - r"Column ColumnTooBig has a maximum value \(.+\)" - r" outside the range supported by Stata \(.+\)" + r"Column ColumnTooBig has a maximum value \(.+\) outside the range " + r"supported by Stata \(.+\)" ) with pytest.raises(ValueError, match=msg): with tm.ensure_clean() as path: @@ -1434,8 +1433,8 @@ def test_out_of_range_double(self): df.loc[2, "ColumnTooBig"] = np.inf msg = ( - "Column ColumnTooBig has a maximum value of infinity which" - " is outside the range supported by Stata" + "Column ColumnTooBig has a maximum value of infinity which is outside " + "the range supported by Stata" ) with pytest.raises(ValueError, match=msg): with tm.ensure_clean() as path: @@ -1706,15 +1705,7 @@ def test_all_none_exception(self, version): output = pd.DataFrame(output) output.loc[:, "none"] = None with tm.ensure_clean() as path: - msg = ( - r"Column `none` cannot be exported\.\n\n" - "Only string-like object 
arrays containing all strings or a" - r" mix of strings and None can be exported\. Object arrays" - r" containing only null values are prohibited\. Other" - " object typescannot be exported and must first be" - r" converted to one of the supported types\." - ) - with pytest.raises(ValueError, match=msg): + with pytest.raises(ValueError, match="Column `none` cannot be exported"): output.to_stata(path, version=version) @pytest.mark.parametrize("version", [114, 117]) @@ -1778,3 +1769,41 @@ def test_stata_119(self): assert df.iloc[0, 7] == 3.14 assert df.iloc[0, -1] == 1 assert df.iloc[0, 0] == pd.Timestamp(datetime(2012, 12, 21, 21, 12, 21)) + + def test_118_writer(self): + cat = pd.Categorical(["a", "β", "ĉ"], ordered=True) + data = pd.DataFrame( + [ + [1.0, 1, "ᴬ", "ᴀ relatively long ŝtring"], + [2.0, 2, "ᴮ", ""], + [3.0, 3, "ᴰ", None], + ], + columns=["a", "β", "ĉ", "strls"], + ) + data["ᴐᴬᵀ"] = cat + variable_labels = { + "a": "apple", + "β": "ᵈᵉᵊ", + "ĉ": "ᴎტჄႲႳႴႶႺ", + "strls": "Long Strings", + "ᴐᴬᵀ": "", + } + data_label = "ᴅaᵀa-label" + data["β"] = data["β"].astype(np.int32) + with tm.ensure_clean() as path: + writer = StataWriter118( + path, + data, + data_label=data_label, + convert_strl=["strls"], + variable_labels=variable_labels, + write_index=False, + ) + writer.write_file() + reread_encoded = read_stata(path) + # Missing is intentionally converted to empty strl + data["strls"] = data["strls"].fillna("") + tm.assert_frame_equal(data, reread_encoded) + reader = StataReader(path) + assert reader.data_label == data_label + assert reader.variable_labels() == variable_labels