import math
from typing import Dict, Hashable, Iterable, List, MutableMapping, Union, Tuple, Optional, Any, Callable

from .dbi import DBI, pd

# Characters that must be backslash-escaped inside a MySQL string literal
# (the same set mysql_real_escape_string escapes). Assumes the server is not
# running with the NO_BACKSLASH_ESCAPES sql_mode.
_STRING_ESCAPES = str.maketrans({
    '\\': '\\\\',
    '\0': '\\0',
    '\n': '\\n',
    '\r': '\\r',
    '\x1a': '\\Z',
    "'": "\\'",
    '"': '\\"',
})


class MySQL(DBI):
    """:class:`DBI` implementation for MySQL, backed by ``mysql.connector``.

    Statements are sent to the server as plain SQL text. Values are rendered
    with :meth:`_parse_value` and table/column names with
    :meth:`_quote_identifier`, so callers do not quote anything themselves.
    """

    def __init__(self, host: Optional[str] = None, user: Optional[str] = None,
                 password: Optional[str] = None, scheme: Optional[str] = None):
        """Connect to database ``scheme`` on ``host`` and open a cursor.

        NOTE(review): ``DBI.__init__`` is expected to store the arguments as
        ``self.host``/``self.user``/``self.password``/``self.scheme``, which
        are read below.
        """
        super().__init__(host, user, password, scheme)
        # Imported lazily (presumably so the driver is only required when
        # this backend is actually used).
        import mysql.connector
        self.BACKEND = mysql
        self.conn = self.BACKEND.connector.connect(host=self.host, user=self.user,
                                                   password=self.password, database=self.scheme)
        self.cursor = self.conn.cursor()

    def _run_with_reconnect(self, action: Callable[[], Any], verb: str, query: str) -> Any:
        """Run ``action()``; after a connection error reconnect and retry exactly once.

        A single retry (instead of unbounded recursion) prevents an endless
        loop when the statement itself keeps triggering a connection-level
        error, e.g. a packet larger than ``max_allowed_packet`` or fetching
        from a statement that returned no result set.

        Args:
            action: Zero-argument callable that performs the database work.
            verb: Word used in the error message ('writing' or 'reading').
            query: SQL text, included in the error message.

        Returns:
            Whatever ``action()`` returns.

        Raises:
            Exception: wrapping (and chained to) the driver error, including
                a connection error that persists after reconnecting. Errors
                raised by :meth:`reconnect` itself propagate unchanged.
        """
        connection_errors = (self.BACKEND.connector.InterfaceError,
                             self.BACKEND.connector.OperationalError)
        try:
            return action()
        except connection_errors:
            self.reconnect()
        except Exception as e:
            raise Exception(f'MySQL raised an error while {verb}: {e.args}. SQL: {query}') from e
        # Second and final attempt on the re-established connection.
        try:
            return action()
        except Exception as e:
            raise Exception(f'MySQL raised an error while {verb}: {e.args}. SQL: {query}') from e

    def write(self, query: str) \
            -> DBI.WriteReturn:
        """Execute a data-modifying statement and commit it.

        Args:
            query: Complete SQL statement.

        Returns:
            ``DBI.WriteReturn`` with the cursor's ``rowcount`` as ``rows`` and
            its ``lastrowid`` as ``key``.

        Raises:
            Exception: if executing or committing fails (one reconnect-and-retry
                is attempted on connection errors).
        """
        def _execute() -> DBI.WriteReturn:
            self.cursor.execute(query)
            self.conn.commit()
            return DBI.WriteReturn(rows=self.cursor.rowcount, key=self.cursor.lastrowid)

        return self._run_with_reconnect(_execute, 'writing', query)

    def insert(self, table: str, rows: Union[pd.DataFrame, Dict[str, Any], List[Dict[str, Any]]],
               ignore: bool = False, update: bool = False, schema: Optional[str] = None) \
            -> DBI.WriteReturn:
        """Insert one or more rows into ``table`` with a single multi-row INSERT.

        Args:
            table: Target table name.
            rows: A single dict, a list/tuple of dicts, or a DataFrame. Columns
                are taken from the dict keys, the first dict of a list, or the
                DataFrame columns. Keys missing from a row are inserted as
                NULL; keys not present in the column set are ignored.
            ignore: Emit ``INSERT IGNORE``.
            update: Append ``ON DUPLICATE KEY UPDATE`` that overwrites every
                inserted column with the new value (uses the ``AS new`` row
                alias, which requires MySQL >= 8.0.19).
            schema: Optional schema/database qualifying ``table``.

        Returns:
            The result of :meth:`write`; ``WriteReturn(rows=0, key=None)``
            without contacting the server when there are no rows.

        Raises:
            ValueError: if ``rows`` has an unsupported type, or a value cannot
                be represented in SQL (see :meth:`_parse_value`).
            Exception: propagated from :meth:`write`.
        """
        columns: List[Any]
        if isinstance(rows, pd.DataFrame):
            columns = list(rows.columns)
            # NOTE: `index=False` is only valid for orient 'split'/'tight' in pandas.
            prep_rows: List[Dict[Any, Any]] = rows.to_dict(orient='records')  # type: ignore
        elif isinstance(rows, dict):
            columns = list(rows.keys())
            prep_rows = [rows]
        elif isinstance(rows, (list, tuple)):
            prep_rows = list(rows)
            columns = list(prep_rows[0].keys()) if prep_rows else []
        else:
            raise ValueError(f'Cannot extract columns from type "{type(rows).__qualname__}"')

        if not prep_rows:
            # Nothing to insert; an empty VALUES clause would be a syntax error.
            return DBI.WriteReturn(rows=0, key=None)

        quoted_columns = [self._quote_identifier(c) for c in columns]
        values_sql = ','.join(
            '(' + ','.join(self._parse_value(row.get(c)) for c in columns) + ')'
            for row in prep_rows
        )
        sql = (f'INSERT {"IGNORE" if ignore else ""} INTO {self._qualified_name(table, schema)} '
               f'({",".join(quoted_columns)}) VALUES {values_sql}')
        if update:
            # The row alias is only needed (and only supported on >= 8.0.19)
            # when referring to the new values in the UPDATE clause.
            assignments = ','.join(f'{q}=`new`.{q}' for q in quoted_columns)
            sql += f' AS `new` ON DUPLICATE KEY UPDATE {assignments}'
        return self.write(sql)

    def update(self, table: str, ids: Dict[str, Any], values: Dict[str, Any], schema: Optional[str] = None) \
            -> DBI.WriteReturn:
        """Run ``UPDATE table SET <values> WHERE <ids>`` (conditions AND-ed).

        Each ``ids`` entry becomes an equality test, so a ``None`` id value
        produces ``= NULL`` and matches no rows.

        Args:
            table: Target table name.
            ids: Column -> value pairs identifying the rows to update.
            values: Column -> new value pairs to assign.
            schema: Optional schema/database qualifying ``table``.

        Returns:
            The result of :meth:`write`.

        Raises:
            ValueError: if ``values`` or ``ids`` is empty (an empty ``ids``
                would otherwise risk an unrestricted UPDATE / invalid SQL).
            Exception: propagated from :meth:`write`.
        """
        if not values:
            raise ValueError('update() requires at least one column in "values"')
        if not ids:
            raise ValueError('update() requires at least one column in "ids"')
        set_sql = ','.join(f'{self._quote_identifier(k)}={self._parse_value(v)}'
                           for k, v in values.items())
        where_sql = ' AND '.join(f'{self._quote_identifier(k)}={self._parse_value(v)}'
                                 for k, v in ids.items())
        return self.write(f'UPDATE {self._qualified_name(table, schema)} SET {set_sql} WHERE {where_sql}')

    def read(self, query: str, single_row: bool = False) \
            -> Union[pd.DataFrame, Optional[pd.Series]]:
        """Execute a SELECT-style query and return its result.

        Args:
            query: SQL query producing a result set.
            single_row: Return only the first row as a Series (or None when
                the result is empty) instead of the whole DataFrame.

        Raises:
            Exception: if the query fails (one reconnect-and-retry is
                attempted on connection errors).
        """
        def _fetch() -> pd.DataFrame:
            self.cursor.execute(query)
            return pd.DataFrame(columns=self.cursor.column_names, data=self.cursor.fetchall())

        df = self._run_with_reconnect(_fetch, 'reading', query)
        if single_row:
            return df.iloc[0] if len(df) > 0 else None
        return df

    def commit(self) -> None:
        """Commit the current transaction on the connection."""
        self.conn.commit()

    def close(self) -> None:
        """Report and close the connection."""
        # NOTE(review): OUTPUT / OUTPUT_KWARGS presumably come from DBI (a print/log hook).
        self.OUTPUT(f'Closing connection to {self.host} -> {self.scheme}', **self.OUTPUT_KWARGS)
        self.conn.close()

    def reconnect(self) -> None:
        """Re-establish the connection (20 attempts, 5 s apart) and open a new cursor."""
        self.OUTPUT(f'Lost connection to database {self.scheme}. Reconnecting...', **self.OUTPUT_KWARGS)
        self.conn.reconnect(20, 5)
        # Fresh cursor so nothing from the dropped session is carried over.
        self.cursor = self.conn.cursor()
        self.OUTPUT('Reconnect successful', **self.OUTPUT_KWARGS)

    def change_scheme(self, scheme) -> None:
        """Close the current connection and reconnect using database ``scheme``."""
        self.close()
        self.__init__(host=self.host, user=self.user, password=self.password, scheme=scheme)

    @staticmethod
    def _quote_identifier(name: Any) -> str:
        """Quote a table/column name with backticks, doubling embedded backticks."""
        return '`' + str(name).replace('`', '``') + '`'

    @classmethod
    def _qualified_name(cls, table: str, schema: Optional[str] = None) -> str:
        """Return the quoted ``schema.table`` (or just ``table`` without a schema)."""
        prefix = f'{cls._quote_identifier(schema)}.' if schema else ''
        return prefix + cls._quote_identifier(table)

    @staticmethod
    def _parse_value(value) -> str:
        """Render a Python value as a MySQL SQL literal.

        - ``None``, ``pd.NaT``, ``pd.NA`` and float NaN -> ``NULL``
        - ``bool`` -> ``TRUE`` / ``FALSE``
        - ``int`` -> its decimal digits
        - ``float`` -> shortest round-trip representation (no precision loss)
        - ``str`` starting and ending with ``~`` -> inserted VERBATIM with the
          surrounding tildes stripped (raw SQL such as ``~NOW()~``; never
          route untrusted input through this form)
        - anything else -> ``str(value)`` as an escaped single-quoted literal

        Raises:
            ValueError: for infinite floats, which MySQL cannot store.
        """
        if value is None or value is pd.NaT or value is pd.NA:
            return 'NULL'
        if isinstance(value, bool):  # must precede int: bool subclasses int
            return 'TRUE' if value else 'FALSE'
        if isinstance(value, int):
            return str(int(value))
        if isinstance(value, float):
            if math.isnan(value):
                return 'NULL'
            if math.isinf(value):
                raise ValueError(f'Cannot store non-finite float {value!r} in MySQL')
            # float() normalizes subclasses such as numpy.float64, whose repr differs.
            return repr(float(value))
        if isinstance(value, str) and value.startswith('~') and value.endswith('~'):
            return value.strip('~')
        return "'" + str(value).translate(_STRING_ESCAPES) + "'"