mysql.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. from typing import Dict, Hashable, Iterable, List, MutableMapping, Union, Tuple, Optional, Any
  2. from .dbi import DBI, pd
  3. class MySQL(DBI):
  4. def __init__(self,
  5. host: Optional[str] = None,
  6. user: Optional[str] = None,
  7. password: Optional[str] = None,
  8. scheme: Optional[str] = None):
  9. super().__init__(host, user, password, scheme)
  10. import mysql.connector
  11. self.BACKEND = mysql
  12. self.conn = self.BACKEND.connector.connect(host=self.host,
  13. user=self.user,
  14. password=self.password,
  15. database=self.scheme)
  16. self.cursor = self.conn.cursor()
  17. def write(self,
  18. query: str) \
  19. -> DBI.WriteReturn:
  20. try:
  21. self.cursor.execute(query)
  22. self.conn.commit()
  23. return DBI.WriteReturn(rows=self.cursor.rowcount,
  24. key=self.cursor.lastrowid)
  25. except (self.BACKEND.connector.InterfaceError,
  26. self.BACKEND.connector.OperationalError):
  27. self.reconnect()
  28. return self.write(query)
  29. except Exception as e:
  30. raise Exception(f'MySQL raised an error while writing: {e.args}. SQL: {query}')
  31. def insert(self,
  32. table: str,
  33. rows: Union[pd.DataFrame, Dict[str, Any], List[Dict[str, Any]]],
  34. ignore: bool = False,
  35. update: bool = False,
  36. schema: Optional[str] = None) \
  37. -> DBI.WriteReturn:
  38. sql: str = f'INSERT {"IGNORE" if ignore else ""} INTO {"`"+schema+"`." if schema else ""}`{table}` '
  39. prep_rows: List[Dict[str, Any]] = []
  40. if type(rows) is dict:
  41. columns: Iterable[str] = rows.keys()
  42. prep_rows = [rows]
  43. elif type(rows) is pd.DataFrame:
  44. columns = rows.columns
  45. prep_rows = rows.to_dict(orient='records', index=False) # type: ignore
  46. elif type(rows) is list:
  47. columns = rows[0].keys()
  48. prep_rows = rows
  49. else:
  50. raise ValueError(f'Cannot extract columns from type "{type(rows).__qualname__}"')
  51. sql += f'({",".join([f"`{c}`" for c in columns])}) VALUES '
  52. # Create each row entry
  53. _data: List[str] = []
  54. for row in prep_rows:
  55. _row: List[str] = []
  56. for column in columns:
  57. _col = row.get(column)
  58. _row.append(self._parse_value(_col))
  59. _data.append(f'({",".join(_row)})')
  60. # Merge row entries to single entry
  61. sql += f'{",".join(_data)} AS `new` '
  62. if update:
  63. sql += f'ON DUPLICATE KEY UPDATE {",".join([f"`{c}`=`new`.`{c}`" for c in columns])}'
  64. return self.write(sql)
  65. def update(self,
  66. table: str,
  67. ids: Dict[str, Any],
  68. values: Dict[str, Any],
  69. schema: Optional[str] = None) \
  70. -> DBI.WriteReturn:
  71. sql: str = f'UPDATE {"`"+schema+"`." if schema else ""}`{table}` '
  72. _val: List[str] = []
  73. for key in values.keys():
  74. _val.append(f'`{key}`={self._parse_value(values.get(key))}')
  75. sql += f'SET {",".join(_val)} '
  76. _id: List[str] = []
  77. for key in ids.keys():
  78. _id.append(f'`{key}`={self._parse_value(ids.get(key))}')
  79. sql += f'WHERE {" AND ".join(_id)}'
  80. return self.write(sql)
  81. def read(self,
  82. query: str,
  83. single_row: bool = False) \
  84. -> Union[pd.DataFrame, Optional[pd.Series]]:
  85. try:
  86. self.cursor.execute(query)
  87. df = pd.DataFrame(columns=self.cursor.column_names, data=self.cursor.fetchall())
  88. if single_row:
  89. return df.iloc[0] if len(df) > 0 else None
  90. return df
  91. except (self.BACKEND.connector.InterfaceError,
  92. self.BACKEND.connector.OperationalError):
  93. self.reconnect()
  94. return self.read(query, single_row)
  95. except Exception as e:
  96. raise Exception(f'MySQL raised an error while reading: {e.args}. SQL: {query}')
  97. def commit(self) -> None:
  98. self.conn.commit()
  99. def close(self) -> None:
  100. self.OUTPUT(f'Closing connection to {self.host} -> {self.scheme}', **self.OUTPUT_KWARGS)
  101. self.conn.close()
  102. def reconnect(self) -> None:
  103. self.OUTPUT(f'Lost connection to database {self.scheme}. Reconnecting...', **self.OUTPUT_KWARGS)
  104. self.conn.reconnect(20, 5)
  105. self.OUTPUT(f'Reconnect successful', **self.OUTPUT_KWARGS)
  106. def change_scheme(self, scheme) -> None:
  107. self.close()
  108. self.__init__(host=self.host, user=self.user, password=self.password, scheme=scheme)
  109. @staticmethod
  110. def _parse_value(value) -> str:
  111. if value is None:
  112. return 'NULL'
  113. elif type(value) is float:
  114. return f'{value:12f}'
  115. elif type(value) is int:
  116. return f'{value}'
  117. elif type(value) is str and value.startswith('~') and value.endswith('~'):
  118. return f'{value.strip("~")}'
  119. else:
  120. return f'\"{str(value)}\"'