| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- from typing import Dict, Hashable, Iterable, List, MutableMapping, Union, Tuple, Optional, Any
- from .dbi import DBI, pd
class MySQL(DBI):
    """MySQL backend for the ``DBI`` interface, built on ``mysql.connector``.

    All statements are sent to the server as SQL text: values are rendered
    with ``_parse_value`` and table/column names with ``_quote_identifier``.
    Connection attributes (``host``, ``user``, ``password``, ``scheme``) and the
    ``OUTPUT``/``OUTPUT_KWARGS`` logging hooks are provided by ``DBI``.
    """

    # Maximum number of times a statement is re-run after a lost connection
    # has been re-established. Bounded so that a persistent failure is raised
    # to the caller instead of being retried forever.
    RECONNECT_RETRIES: int = 3

    # Translation table for the body of a single-quoted MySQL string literal.
    # Assumes the server is NOT running with sql_mode NO_BACKSLASH_ESCAPES
    # (MySQL's default), i.e. backslash escape sequences are honoured.
    _SQL_ESCAPES: Dict[int, str] = str.maketrans({
        '\\': '\\\\',
        '\0': '\\0',
        '\n': '\\n',
        '\r': '\\r',
        '\x1a': '\\Z',
        "'": "\\'",
        '"': '\\"',
    })

    def __init__(self,
                 host: Optional[str] = None,
                 user: Optional[str] = None,
                 password: Optional[str] = None,
                 scheme: Optional[str] = None):
        """Connect to database ``scheme`` on ``host`` and open a cursor.

        The arguments are handed to ``DBI.__init__`` and then read back from
        ``self.host``/``self.user``/``self.password``/``self.scheme``.
        """
        super().__init__(host, user, password, scheme)
        # Imported here rather than at module level (presumably so the driver
        # is only required when this backend is actually used).
        import mysql.connector
        self.BACKEND = mysql
        self.conn = self.BACKEND.connector.connect(host=self.host,
                                                   user=self.user,
                                                   password=self.password,
                                                   database=self.scheme)
        self.cursor = self.conn.cursor()

    def write(self,
              query: str) \
            -> DBI.WriteReturn:
        """Execute a data-modifying statement and commit it.

        When the statement fails with ``InterfaceError``/``OperationalError``
        AND the connection is no longer alive, the connection is re-established
        and the statement re-run, at most ``RECONNECT_RETRIES`` times. Errors on
        a live connection are not retried, so a failing statement is never
        executed repeatedly.

        :param query: complete SQL statement.
        :return: ``DBI.WriteReturn`` with the cursor's ``rowcount`` and
            ``lastrowid``.
        :raises Exception: wrapping the driver error, including the SQL text.
        """
        connection_errors = (self.BACKEND.connector.InterfaceError,
                             self.BACKEND.connector.OperationalError)
        attempt = 0
        while True:
            try:
                self.cursor.execute(query)
                self.conn.commit()
                return DBI.WriteReturn(rows=self.cursor.rowcount,
                                       key=self.cursor.lastrowid)
            except connection_errors as e:
                if attempt >= self.RECONNECT_RETRIES or self.conn.is_connected():
                    raise Exception(f'MySQL raised an error while writing: {e.args}. SQL: {query}') from e
                attempt += 1
                # Errors raised by reconnect() itself propagate unchanged.
                self.reconnect()
            except Exception as e:
                raise Exception(f'MySQL raised an error while writing: {e.args}. SQL: {query}') from e

    def insert(self,
               table: str,
               rows: Union[pd.DataFrame, Dict[str, Any], List[Dict[str, Any]]],
               ignore: bool = False,
               update: bool = False,
               schema: Optional[str] = None) \
            -> DBI.WriteReturn:
        """Insert one or more rows into ``table`` with a single statement.

        :param table: target table name.
        :param rows: one row as a dict, a list/tuple of dicts, or a DataFrame.
            Column names come from the dict keys (the FIRST dict for a list) or
            from the DataFrame columns. Keys missing from a later row are
            inserted as NULL; keys not in the column set are ignored.
        :param ignore: emit ``INSERT IGNORE``.
        :param update: on a duplicate key, overwrite every inserted column with
            the new value (``ON DUPLICATE KEY UPDATE`` using the ``new`` row
            alias, which requires MySQL 8.0.19+).
        :param schema: optional database name used to qualify ``table``.
        :return: the result of ``write``.
        :raises ValueError: for an unsupported ``rows`` type or when ``rows``
            contains no rows.
        """
        if isinstance(rows, dict):
            columns: List[Any] = list(rows.keys())
            prep_rows: List[Dict[Any, Any]] = [rows]
        elif isinstance(rows, pd.DataFrame):
            columns = list(rows.columns)
            # 'records' never includes the index; pandas rejects index=False
            # for this orient, so the argument must not be passed.
            prep_rows = rows.to_dict(orient='records')
        elif isinstance(rows, (list, tuple)):
            columns = list(rows[0].keys()) if rows else []
            prep_rows = list(rows)
        else:
            raise ValueError(f'Cannot extract columns from type "{type(rows).__qualname__}"')
        if not prep_rows:
            raise ValueError(f'No rows to insert into "{table}"')

        target = self._quote_identifier(table)
        if schema:
            target = f'{self._quote_identifier(schema)}.{target}'
        quoted_columns = [self._quote_identifier(c) for c in columns]
        values_sql = ','.join(
            '(' + ','.join(self._parse_value(row.get(c)) for c in columns) + ')'
            for row in prep_rows)
        sql = (f'INSERT {"IGNORE " if ignore else ""}INTO {target} '
               f'({",".join(quoted_columns)}) VALUES {values_sql}')
        if update:
            # The row alias is only emitted when it is needed, so plain inserts
            # also work on servers older than MySQL 8.0.19.
            assignments = ','.join(f'{name}=`new`.{name}' for name in quoted_columns)
            sql += f' AS `new` ON DUPLICATE KEY UPDATE {assignments}'
        return self.write(sql)

    def update(self,
               table: str,
               ids: Dict[str, Any],
               values: Dict[str, Any],
               schema: Optional[str] = None) \
            -> DBI.WriteReturn:
        """Set ``values`` on every row of ``table`` matching all of ``ids``.

        :param table: target table name.
        :param ids: column -> value pairs combined with ``AND`` in the WHERE
            clause; values rendering as NULL are compared with ``IS NULL``.
        :param values: column -> new value pairs for the SET clause.
        :param schema: optional database name used to qualify ``table``.
        :return: the result of ``write``.
        :raises ValueError: if ``ids`` or ``values`` is empty (the statement
            would be invalid SQL).
        """
        if not values:
            raise ValueError('update() requires at least one column in "values"')
        if not ids:
            raise ValueError('update() requires at least one column in "ids"')

        target = self._quote_identifier(table)
        if schema:
            target = f'{self._quote_identifier(schema)}.{target}'
        assignments = ','.join(
            f'{self._quote_identifier(key)}={self._parse_value(value)}'
            for key, value in values.items())
        conditions: List[str] = []
        for key, value in ids.items():
            rendered = self._parse_value(value)
            # `col = NULL` is never true in SQL; NULL needs `IS NULL`.
            operator = ' IS ' if rendered == 'NULL' else '='
            conditions.append(f'{self._quote_identifier(key)}{operator}{rendered}')
        sql = f'UPDATE {target} SET {assignments} WHERE {" AND ".join(conditions)}'
        return self.write(sql)

    def read(self,
             query: str,
             single_row: bool = False) \
            -> Union[pd.DataFrame, Optional[pd.Series]]:
        """Run a query and return its result set.

        Reconnect handling is the same as in ``write``: only connection errors
        on a dead connection are retried, at most ``RECONNECT_RETRIES`` times.

        :param query: SQL statement that produces a result set.
        :param single_row: return only the first row.
        :return: a DataFrame of all rows, or with ``single_row`` the first row
            as a Series (``None`` when the result is empty).
        :raises Exception: wrapping the driver error, including the SQL text.
        """
        connection_errors = (self.BACKEND.connector.InterfaceError,
                             self.BACKEND.connector.OperationalError)
        attempt = 0
        while True:
            try:
                self.cursor.execute(query)
                frame = pd.DataFrame(columns=self.cursor.column_names,
                                     data=self.cursor.fetchall())
            except connection_errors as e:
                if attempt >= self.RECONNECT_RETRIES or self.conn.is_connected():
                    raise Exception(f'MySQL raised an error while reading: {e.args}. SQL: {query}') from e
                attempt += 1
                self.reconnect()
                continue
            except Exception as e:
                raise Exception(f'MySQL raised an error while reading: {e.args}. SQL: {query}') from e
            if single_row:
                return frame.iloc[0] if len(frame) > 0 else None
            return frame

    def commit(self) -> None:
        """Commit the current transaction on the connection."""
        self.conn.commit()

    def close(self) -> None:
        """Log and close the connection."""
        self.OUTPUT(f'Closing connection to {self.host} -> {self.scheme}', **self.OUTPUT_KWARGS)
        self.conn.close()

    def reconnect(self) -> None:
        """Re-establish the connection and open a fresh cursor on it.

        Calls ``conn.reconnect(20, 5)`` (attempts, delay in seconds); if that
        fails, the driver's exception propagates.
        """
        self.OUTPUT(f'Lost connection to database {self.scheme}. Reconnecting...', **self.OUTPUT_KWARGS)
        self.conn.reconnect(20, 5)
        # Replace the cursor so nothing from the dropped session is reused.
        self.cursor = self.conn.cursor()
        self.OUTPUT('Reconnect successful', **self.OUTPUT_KWARGS)

    def change_scheme(self, scheme: Optional[str]) -> None:
        """Close the current connection and reconnect to database ``scheme``."""
        self.close()
        self.__init__(host=self.host, user=self.user, password=self.password, scheme=scheme)

    @staticmethod
    def _quote_identifier(name: Any) -> str:
        """Quote ``name`` as a MySQL identifier, doubling embedded backticks."""
        return '`' + str(name).replace('`', '``') + '`'

    @staticmethod
    def _parse_value(value: Any) -> str:
        """Render a Python value as a MySQL literal for embedding in SQL text.

        - ``None``, float NaN, ``pd.NaT`` and ``pd.NA`` -> ``NULL``
        - ``bool`` -> ``TRUE`` / ``FALSE`` (checked before ``int``, its base class)
        - ``int`` -> decimal digits
        - ``float`` -> shortest round-trip representation (no precision loss)
        - ``bytes`` / ``bytearray`` -> hexadecimal literal ``X'...'``
        - a ``str`` wrapped in ``~`` (e.g. ``'~NOW()~'``) -> the inner text
          verbatim, i.e. a raw SQL expression. Never use this form for
          untrusted data: it bypasses escaping.
        - anything else -> ``str(value)`` as an escaped, single-quoted literal.
        """
        if value is None or value is pd.NaT or value is pd.NA:
            return 'NULL'
        if isinstance(value, bool):
            return 'TRUE' if value else 'FALSE'
        if isinstance(value, int):
            return str(int(value))
        if isinstance(value, float):
            if value != value:  # NaN is the only float that is not equal to itself
                return 'NULL'
            return repr(float(value))
        if isinstance(value, (bytes, bytearray)):
            return f"X'{value.hex()}'"
        if isinstance(value, str) and len(value) >= 2 and value.startswith('~') and value.endswith('~'):
            return value[1:-1]
        return "'" + str(value).translate(MySQL._SQL_ESCAPES) + "'"
|