BUG: Fix bugs in stata
Fix incorrect skipping in strl writer
Fix incorrect byteorder when exporting bigendian
Fix incorrect byteorder parsing when importing bigendian
Improve test coverage for errors
bashtage committed May 1, 2018
commit a5f16532c0fa24e2715a57b35ae669c1f4ebd8c7
50 changes: 26 additions & 24 deletions pandas/io/stata.py
@@ -1050,7 +1050,7 @@ def _read_new_header(self, first_char):
if self.format_version not in [117, 118]:
raise ValueError(_version_error)
self.path_or_buf.read(21) # </release><byteorder>
self.byteorder = self.path_or_buf.read(3) == "MSF" and '>' or '<'
self.byteorder = self.path_or_buf.read(3) == b'MSF' and '>' or '<'
self.path_or_buf.read(15) # </byteorder><K>
self.nvar = struct.unpack(self.byteorder + 'H',
self.path_or_buf.read(2))[0]
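Note: the b'MSF' comparison above is the whole import-side byteorder fix. Under Python 3 the header read returns bytes, which never compare equal to the str "MSF", so every file was treated as little-endian. A minimal standalone sketch of the corrected detection (the header bytes here are illustrative):

import io

path_or_buf = io.BytesIO(b'MSF')           # 3-byte tag from a big-endian dta header
tag = path_or_buf.read(3)                  # bytes under Python 3
assert tag != "MSF"                        # bytes == str is always False: the old bug
byteorder = '>' if tag == b'MSF' else '<'  # 'MSF' = most significant first = big-endian
assert byteorder == '>'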
@@ -1824,9 +1824,7 @@ def _dtype_to_stata_type(dtype, column):
type inserted.
"""
# TODO: expand to handle datetime to integer conversion
if dtype.type == np.string_:
return dtype.itemsize
elif dtype.type == np.object_: # try to coerce it to the biggest string
if dtype.type == np.object_: # try to coerce it to the biggest string
# not memory efficient, what else could we
# do?
itemsize = max_len_string_array(_ensure_object(column.values))
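Note: pandas stores strings in object-dtype columns, so they take the branch kept above, which sizes the Stata str type from the longest value in the column. A rough standalone sketch (a plain max stands in for max_len_string_array, and the width-to-type mapping is simplified):

import pandas as pd

col = pd.Series(['a', 'abc', 'ab'], dtype=object)
itemsize = max(len(s) for s in col)   # approximates max_len_string_array
stata_type = itemsize                 # 3 here, i.e. roughly a Stata str3 column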
@@ -2347,25 +2345,30 @@ def _prepare_data(self):
data = self._convert_strls(data)

# 3. Convert bad string data to '' and pad to correct length
dtype = []
dtypes = []
data_cols = []
has_strings = False
native_byteorder = self._byteorder == _set_endianness(sys.byteorder)
for i, col in enumerate(data):
typ = typlist[i]
if typ <= self._max_string_length:
has_strings = True
data[col] = data[col].fillna('').apply(_pad_bytes, args=(typ,))
stype = 'S%d' % typ
dtype.append(('c' + str(i), stype))
dtypes.append(('c' + str(i), stype))
string = data[col].str.encode(self._encoding)
data_cols.append(string.values.astype(stype))
else:
dtype.append(('c' + str(i), data[col].dtype))
data_cols.append(data[col].values)
dtype = np.dtype(dtype)

if has_strings:
self.data = np.fromiter(zip(*data_cols), dtype=dtype)
values = data[col].values
dtype = data[col].dtype
if not native_byteorder:
dtype = dtype.newbyteorder(self._byteorder)
dtypes.append(('c' + str(i), dtype))
data_cols.append(values)
dtypes = np.dtype(dtypes)

if has_strings or not native_byteorder:
self.data = np.fromiter(zip(*data_cols), dtype=dtypes)
else:
self.data = data.to_records(index=False)
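Note: the export-side byteorder fix above works by describing each non-string column with a byte-swapped dtype and letting np.fromiter do the conversion while it builds the record array. A small standalone sketch of that mechanism (column name and values are made up):

import numpy as np

values = np.array([1.5, 2.5, 3.5])            # native-order float64 column
swapped = values.dtype.newbyteorder('>')      # same type, big-endian storage
dtypes = np.dtype([('c0', swapped)])
# Filling a structured array converts each value to the field's byte order,
# so the resulting records are big-endian even on a little-endian machine.
records = np.fromiter(zip(values), dtype=dtypes)
assert float(records['c0'][1]) == 2.5         # values unchanged, only the layout differs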

@@ -2403,9 +2406,7 @@ def _dtype_to_stata_type_117(dtype, column, force_strl):
# TODO: expand to handle datetime to integer conversion
if force_strl:
return 32768
if dtype.type == np.string_:
return chr(dtype.itemsize)
elif dtype.type == np.object_: # try to coerce it to the biggest string
if dtype.type == np.object_: # try to coerce it to the biggest string
# not memory efficient, what else could we
# do?
itemsize = max_len_string_array(_ensure_object(column.values))
@@ -2513,11 +2514,13 @@ def generate_table(self):
Ordered dictionary using the string found as keys
and their lookup position (v,o) as values
gso_df : DataFrame
Copy of DataFrame where strl columns have been converted
to encoded (v,o) values
DataFrame where strl columns have been converted to
(v,o) values

Notes
-----
Modifies the DataFrame in-place.

The DataFrame returned encodes the (v,o) values as uint64s. The
encoding depends on the dta version, and can be expressed as

@@ -2532,10 +2535,9 @@ def generate_table(self):
"""

gso_table = self._gso_table
df_out = self.df.copy()
df = self.df
columns = list(df.columns)
selected = df[self.columns]
gso_df = self.df
columns = list(gso_df.columns)
selected = gso_df[self.columns]
col_index = [(col, columns.index(col)) for col in self.columns]
keys = np.empty(selected.shape, dtype=np.uint64)
for o, (idx, row) in enumerate(selected.iterrows()):
@@ -2548,9 +2550,9 @@ def generate_table(self):
gso_table[val] = key
keys[o, j] = self._convert_key(key)
for i, col in enumerate(self.columns):
df_out[col] = keys[:, i]
gso_df[col] = keys[:, i]

return gso_table, df_out
return gso_table, gso_df

def _encode(self, s):
"""
@@ -2599,7 +2601,7 @@ def generate_blob(self, table):
o_type = self._byteorder + self._gso_o_type
len_type = self._byteorder + 'I'
for strl, vo in table.items():
if vo == 0:
if vo == (0, 0):
continue
v, o = vo
# GSO
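Note: the (0, 0) comparison in generate_blob is the strl skipping fix. Assuming the lookup table seeds the empty string at position (0, 0) (as the writer's setup does elsewhere), that sentinel must not be emitted as a GSO record; the old test vo == 0 never matched a tuple, so a spurious record was written. A rough sketch with made-up table contents:

gso_table = {'': (0, 0), 'first long string': (1, 1), 'second long string': (2, 1)}
blobs = []
for strl, vo in gso_table.items():
    if vo == (0, 0):      # skip the sentinel entry for the empty string
        continue
    v, o = vo
    blobs.append((v, o, strl))   # the real writer packs v, o, a type byte, a length and the encoded text
assert len(blobs) == 2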
61 changes: 58 additions & 3 deletions pandas/tests/io/test_stata.py
@@ -499,6 +499,15 @@ def test_timestamp_and_label(self, version):
assert reader.time_stamp == '29 Feb 2000 14:21'
assert reader.data_label == data_label

@pytest.mark.parametrize('version', [114, 117])
def test_invalid_timestamp(self, version):
original = DataFrame([(1,)], columns=['variable'])
time_stamp = '01 Jan 2000, 00:00:00'
with tm.ensure_clean() as path:
with pytest.raises(ValueError):
original.to_stata(path, time_stamp=time_stamp,
version=version)

def test_numeric_column_names(self):
original = DataFrame(np.reshape(np.arange(25.0), (5, 5)))
original.index.name = 'index'
@@ -639,7 +648,8 @@ def test_write_missing_strings(self):
expected)

@pytest.mark.parametrize('version', [114, 117])
def test_bool_uint(self, version):
@pytest.mark.parametrize('byteorder', ['>', '<'])
def test_bool_uint(self, byteorder, version):
s0 = Series([0, 1, True], dtype=np.bool)
s1 = Series([0, 1, 100], dtype=np.uint8)
s2 = Series([0, 1, 255], dtype=np.uint8)
@@ -658,7 +668,7 @@ def test_bool_uint(self, version):
expected[c] = expected[c].astype(t)

with tm.ensure_clean() as path:
original.to_stata(path, version=version)
original.to_stata(path, byteorder=byteorder, version=version)
written_and_read_again = self.read_dta(path)
written_and_read_again = written_and_read_again.set_index('index')
tm.assert_frame_equal(written_and_read_again, expected)
@@ -1173,6 +1183,29 @@ def test_write_variable_labels(self, version):
read_labels = sr.variable_labels()
assert read_labels == variable_labels

@pytest.mark.parametrize('version', [114, 117])
def test_invalid_variable_labels(self, version):
original = pd.DataFrame({'a': [1, 2, 3, 4],
'b': [1.0, 3.0, 27.0, 81.0],
'c': ['Atlanta', 'Birmingham',
'Cincinnati', 'Detroit']})
original.index.name = 'index'
variable_labels = {'a': 'very long' * 10,
'b': 'City Exponent',
'c': 'City'}
with tm.ensure_clean() as path:
with pytest.raises(ValueError):
original.to_stata(path,
variable_labels=variable_labels,
version=version)

variable_labels['a'] = u'invalid character Œ'
with tm.ensure_clean() as path:
with pytest.raises(ValueError):
original.to_stata(path,
variable_labels=variable_labels,
version=version)

def test_write_variable_label_errors(self):
original = pd.DataFrame({'a': [1, 2, 3, 4],
'b': [1.0, 3.0, 27.0, 81.0],
@@ -1220,6 +1253,13 @@ def test_default_date_conversion(self):
direct = read_stata(path, convert_dates=True)
tm.assert_frame_equal(reread, direct)

dates_idx = original.columns.tolist().index('dates')
original.to_stata(path,
write_index=False,
convert_dates={dates_idx: 'tc'})
direct = read_stata(path, convert_dates=True)
tm.assert_frame_equal(reread, direct)

def test_unsupported_type(self):
original = pd.DataFrame({'a': [1 + 2j, 2 + 4j]})

@@ -1394,7 +1434,7 @@ def test_writer_117(self):
original['float32'] = Series(original['float32'], dtype=np.float32)
original.index.name = 'index'
original.index = original.index.astype(np.int32)

copy = original.copy()
with tm.ensure_clean() as path:
original.to_stata(path,
convert_dates={'datetime': 'tc'},
@@ -1404,6 +1444,7 @@ def test_writer_117(self):
# original.index is np.int32, read index is np.int64
tm.assert_frame_equal(written_and_read_again.set_index('index'),
original, check_index_type=False)
tm.assert_frame_equal(original, copy)

def test_convert_strl_name_swap(self):
original = DataFrame([['a' * 3000, 'A', 'apple'],
@@ -1419,3 +1460,17 @@ def test_convert_strl_name_swap(self):
reread.columns = original.columns
tm.assert_frame_equal(reread, original,
check_index_type=False)

def test_invalid_date_conversion(self):
# GH 12259
dates = [dt.datetime(1999, 12, 31, 12, 12, 12, 12000),
dt.datetime(2012, 12, 21, 12, 21, 12, 21000),
dt.datetime(1776, 7, 4, 7, 4, 7, 4000)]
original = pd.DataFrame({'nums': [1.0, 2.0, 3.0],
'strs': ['apple', 'banana', 'cherry'],
'dates': dates})

with tm.ensure_clean() as path:
with pytest.raises(ValueError):
original.to_stata(path,
convert_dates={'wrong_name': 'tc'})