import csv
from enum import Enum
from typing import List, Optional


def cut_str(text: str, cut: str) -> str:
    """
    文字列カット関数
    Args:
        text (str): カットしたい文字を含む文字列
        cut (str): カットする文字列
    Returns:
        str: カット後の文字列
    """
    pos = text.find(cut)
    if 0 <= pos:
        return text[pos + len(cut) : len(text)]
    return text


class TABLE_TYPE(Enum):
    """
    テーブルタイプ（enum）
    """

    ROW = 1
    KEY = 2
    KEY_UNIT = 3
    NONE = -1


class Ss7CsvInputTable:
    """
    テーブルクラス
    """

    def __init__(self) -> None:
        """
        コンストラクタ
        """
        self.name: str = ""
        self.case: str = ""
        self.tab: str = ""
        self.group: str = ""
        self.type: TABLE_TYPE = TABLE_TYPE.NONE
        self.fields: List[List[str]] = []
        self.Units: List[str] = []
        self.records: List[List[str]] = []

    def set_name(self, name: str, case: str, tab: str, group: str) -> None:
        """
        テーブル名の登録
        Args:
            name (str): テーブル名（name=）
            case (str): ケース名（case=）
            tab (str): タブ名（tab=）
            group (str): グループ名（group=）
        """
        self.name = name
        self.case = case
        self.tab = tab
        self.group = group
        self.type = TABLE_TYPE.NONE

    def set_csv_unit(self, data: [str]) -> None:
        """
        単位定義の登録
        Args:
            data (str]): CSV一行分のデータ
        """
        self.Units = data

    def set_csv_field(self, data: [str]) -> None:
        """
        列名の登録
        Args:
            data (str]): CSV一行分のデータ
        """
        self.fields.append(data)

    def set_csv_data(self, data: [str]) -> None:
        """
        値の登録(row型)
        Args:
            data (str]): CSV一行分のデータ
        """
        self.type = TABLE_TYPE.ROW
        self.records.append(data)

    def set_csv_key_data(self, data: [str]) -> None:
        """
        値の登録(key型単位なし)
        Args:
            data (str]): CSV一行分のデータ
        """
        self.type = TABLE_TYPE.KEY
        self.records.append([data[-1]])
        self.fields.append(data)

    def set_csv_key_unit_data(self, data: [str]) -> None:
        """
        値の登録(key型単位あり)
        Args:
            data (str]): CSV一行分のデータ
        """
        self.type = TABLE_TYPE.KEY_UNIT
        self.records.append([data[-1]])
        self.Units.append(data[-2])
        self.fields.append(data)

    def output(self, data: [str]) -> None:
        """
        出力
        Args:
            data (str]): 出力先
        """

        # テーブル名
        names = []
        names.append("name=" + self.name)
        if not self.case == "":
            names.append("case=" + self.case)
        if not self.tab == "":
            names.append("tab=" + self.tab)
        if not self.group == "":
            names.append("group=" + self.group)
        data.append(names)

        # rowタイプ
        if self.type == TABLE_TYPE.ROW:
            for field in self.fields:
                data.append(field)
            if not len(self.Units) == 0:
                data.append(["<unit>"])
                data.append(self.Units)
            data.append(["<data>"])
            for rec in self.records:
                data.append(rec)

        # keyタイプ
        elif self.type == TABLE_TYPE.KEY:
            def_row = []
            for i in range(len(self.fields[0])):
                def_row.append("")
            def_row[-1] = "<data>"
            data.append(def_row)
            for field, rec in zip(self.fields, self.records):
                field[-1] = rec[0]
                data.append(field)

        # key(単位付)タイプ
        elif self.type == TABLE_TYPE.KEY_UNIT:
            def_row = []
            for i in range(len(self.fields[0])):
                def_row.append("")
            def_row[-2] = "<unit>"
            def_row[-1] = "<data>"
            data.append(def_row)
            for field, rec in zip(self.fields, self.records):
                field[-1] = rec[0]
                data.append(field)

    def get_row_count(self) -> int:
        """
        データの行数の取得
        Returns:
            int: 行数
        """
        return len(self.records)

    def get_col_count(self) -> int:
        """
        データの列数の取得
        Returns:
            int: 列数
        """
        return len(self.records[0])

    def get_data(self, row: int, col: int) -> str:
        """
        値の取得
        Args:
            row (int): 行インデックス（0～）
            col (int): 列インデックス（0～）
        Returns:
            str: 値
        """
        if self.type == TABLE_TYPE.ROW:
            if row < len(self.records):
                if col < len(self.records[row]):
                    return self.records[row][col]

        if self.type == TABLE_TYPE.KEY or self.type == TABLE_TYPE.KEY_UNIT:
            if row == 0 and col < len(self.records):
                return self.records[col][row]
        return ""

    def get_data_key1(self, key: str, col: int) -> str:
        """
        値の取得
        Args:
            key (str): キー文字列
            col (int): 列インデックス（0～）
        Returns:
            str: キーで検索した値
        """
        if self.type == TABLE_TYPE.ROW:
            if col < self.get_col_count():
                for rec in self.records:
                    if rec[0] == key:
                        if col < len(rec):
                            return rec[col]
        return ""

    def get_data_key2(self, key1: str, key2: str, col: int) -> str:
        """
        値の取得
        Args:
            key1～ (str): キー文字列
            col (int): 列インデックス（0～）
        Returns:
            str: キーで検索した値
        """
        if self.type == TABLE_TYPE.ROW:
            if col < self.get_col_count():
                for rec in self.records:
                    if rec[0] == key1:
                        if rec[1] == key2:
                            if col < len(rec):
                                return rec[col]
        return ""

    def get_data_key3(self, key1: str, key2: str, key3: str, col: int) -> str:
        """
        値の取得
        Args:
            key1～ (str): キー文字列
            col (int): 列インデックス（0～）
        Returns:
            str: キーで検索した値
        """
        if self.type == TABLE_TYPE.ROW:
            if col < self.get_col_count():
                for rec in self.records:
                    if rec[0] == key1:
                        if rec[1] == key2:
                            if rec[2] == key3:
                                if col < len(rec):
                                    return rec[col]
        return ""

    def set_data(self, row: int, col: int, data: str) -> int:
        """
        値の登録
        Args:
            row (int): 行インデックス（0～）
            col (int): 列インデックス（0～）
            data(str): 登録する値
        Returns:
            int: 0=成功 -1=失敗
        """
        if self.type == TABLE_TYPE.ROW:
            if row < len(self.records):
                if col < len(self.records[row]):
                    self.records[row][col] = data
                    return 0
        if self.type == TABLE_TYPE.KEY or self.type == TABLE_TYPE.KEY_UNIT:
            if row == 0 and col < len(self.records):
                self.records[col][row] = data
                return 0

        return -1

    def set_data_key1(self, key: str, col: int, data: str) -> int:
        """
        値の登録
        Args:
            key (str): キー文字列
            col (int): 列インデックス（0～）
            data(str): 登録する値
        Returns:
            int: 0=成功 -1=失敗
        """
        if self.type == TABLE_TYPE.ROW:
            if col < self.get_col_count():
                for rec in self.records:
                    if rec[0] == key:
                        if col < len(rec):
                            rec[col] = data
                            return 0
        return -1

    def set_data_key2(self, key1: str, key2: str, col: int, data: str) -> int:
        """
        値の登録
        Args:
            key1～ (str): キー文字列
            col (int): 列インデックス（0～）
            data(str): 登録する値
        Returns:
            int: 0=成功 -1=失敗
        """
        if self.type == TABLE_TYPE.ROW:
            if col < self.get_col_count():
                for rec in self.records:
                    if rec[0] == key1:
                        if rec[1] == key2:
                            if col < len(rec):
                                rec[col] = data
                                return 0
        return -1

    def set_data_key3(
        self, key1: str, key2: str, key3: str, col: int, data: str
    ) -> int:
        """
        値の登録
        Args:
            key1～ (str): キー文字列
            col (int): 列インデックス（0～）
            data(str): 登録する値
        Returns:
            int: 0=成功 -1=失敗
        """
        if self.type == TABLE_TYPE.ROW:
            if col < self.get_col_count():
                for rec in self.records:
                    if rec[0] == key1:
                        if rec[1] == key2:
                            if rec[2] == key3:
                                if col < len(rec):
                                    rec[col] = data
                                    return 0
        return -1

    def search_col1(self, field: str) -> int:
        """
        列の検索
        Args:
            field (str): 列名
        Returns:
            int: 列インデックス（0～）
        """
        # rowタイプ
        if self.type == TABLE_TYPE.ROW:
            col = 0
            if 1 <= len(self.fields[0]):
                for self_field in self.fields[0]:
                    if self_field == field:
                        return col
                    col = col + 1

        # keyタイプ
        elif self.type == TABLE_TYPE.KEY or self.type == TABLE_TYPE.KEY_UNIT:
            col = 0
            for self_fields in self.fields:
                if self_fields[0] == field:
                    return col
                col = col + 1

        return -1

    def search_col2(self, field1: str, field2: str) -> int:
        """
        列の検索
        Args:
            field1 (str): 列名（上段）
            field2 (str): 列名（下段）
        Returns:
            int: 列インデックス（0～）
        """
        # rowタイプ
        if self.type == TABLE_TYPE.ROW:
            col = 0
            self_field1_pre = ""
            if 2 <= len(self.fields):
                for self_field1, self_field2 in zip(self.fields[0], self.fields[1]):
                    # 上段列名が空白の場合は直前の列名を引き継ぐ
                    if self_field1 == "":
                        self_field1 = self_field1_pre

                    if self_field1 == field1 and self_field2 == field2:
                        return col

                    col = col + 1
                    self_field1_pre = self_field1
        # keyタイプ
        elif self.type == TABLE_TYPE.KEY or self.type == TABLE_TYPE.KEY_UNIT:
            col = 0
            for self_fields in self.fields:
                self_field1 = self_fields[0]
                self_field2 = self_fields[1]
                if self_field1 == "":
                    self_field1 = self_field1_pre

                if self_field1 == field1 and self_field2 == field2:
                    return col

                col = col + 1
                self_field1_pre = self_field1

        return -1

    def search_col3(self, field1: str, field2: str, field3: str) -> int:
        """
        列の検索
        Args:
            field1 (str): 列名（上段）
            field2 (str): 列名（中段）
            field3 (str): 列名（下段）
        Returns:
            int: 列インデックス（0～）
        """
        # rowタイプ
        if self.type == TABLE_TYPE.ROW:
            col = 0
            self_field1_pre = ""
            self_field2_pre = ""
            if 3 <= len(self.fields):
                for self_field1, self_field2, self_field3 in zip(
                    self.fields[0], self.fields[1], self.fields[2]
                ):
                    # 上段列名が空白の場合は直前の列名を引き継ぐ
                    if self_field1 == "":
                        self_field1 = self_field1_pre
                    if self_field2 == "":
                        self_field2 = self_field2_pre

                    if (
                        self_field1 == field1
                        and self_field2 == field2
                        and self_field3 == field3
                    ):
                        return col

                    col = col + 1
                    self_field1_pre = self_field1
                    self_field2_pre = self_field2
        # keyタイプ
        elif self.type == TABLE_TYPE.KEY or self.type == TABLE_TYPE.KEY_UNIT:
            col = 0
            self_field1_pre = ""
            self_field2_pre = ""
            for self_fields in self.fields:
                self_field1 = self_fields[0]
                self_field2 = self_fields[1]
                self_field3 = self_fields[2]
                if self_field1 == "":
                    self_field1 = self_field1_pre
                if self_field2 == "":
                    self_field2 = self_field2_pre

                if (
                    self_field1 == field1
                    and self_field2 == field2
                    and self_field3 == field3
                ):
                    return col

                col = col + 1
                self_field1_pre = self_field1
                self_field2_pre = self_field2

        return -1

    def search_col4(self, field1: str, field2: str, field3: str, field4: str) -> int:
        """
        列の検索
        Args:
            field1 (str): 列名（1段目）
            field2 (str): 列名（2段目）
            field3 (str): 列名（3段目）
            field4 (str): 列名（4段目）
        Returns:
            int: 列インデックス（0～）
        """
        # rowタイプ
        if self.type == TABLE_TYPE.ROW:
            col = 0
            self_field1_pre = ""
            self_field2_pre = ""
            self_field3_pre = ""
            if 3 <= len(self.fields):
                for self_field1, self_field2, self_field3, self_field4 in zip(
                    self.fields[0], self.fields[1], self.fields[2], self.fields[3]
                ):
                    # 上段列名が空白の場合は直前の列名を引き継ぐ
                    if self_field1 == "":
                        self_field1 = self_field1_pre
                    if self_field2 == "":
                        self_field2 = self_field2_pre
                    if self_field3 == "":
                        self_field3 = self_field3_pre

                    if (
                        self_field1 == field1
                        and self_field2 == field2
                        and self_field3 == field3
                        and self_field4 == field4
                    ):
                        return col

                    col = col + 1
                    self_field1_pre = self_field1
                    self_field2_pre = self_field2
                    self_field3_pre = self_field3
        # keyタイプ
        elif self.type == TABLE_TYPE.KEY or self.type == TABLE_TYPE.KEY_UNIT:
            col = 0
            self_field1_pre = ""
            self_field2_pre = ""
            self_field3_pre = ""
            for self_fields in self.fields:
                self_field1 = self_fields[0]
                self_field2 = self_fields[1]
                self_field3 = self_fields[2]
                self_field4 = self_fields[3]
                if self_field1 == "":
                    self_field1 = self_field1_pre
                if self_field2 == "":
                    self_field2 = self_field2_pre
                if self_field3 == "":
                    self_field3 = self_field3_pre

                if (
                    self_field1 == field1
                    and self_field2 == field2
                    and self_field3 == field3
                    and self_field4 == field4
                ):
                    return col

                col = col + 1
                self_field1_pre = self_field1
                self_field2_pre = self_field2
                self_field3_pre = self_field3

        return -1


class Ss7CsvInput:
    """
    入力CSV読み込みクラス
    """

    def __init__(self, file_path: str):
        """
        コンストラクタ
        Args:
            file_path (str): CSVファイルのパス
        """
        self.path = file_path
        self.headers = []
        self.tables = []

        with open(file_path, mode="r", encoding="cp932") as File:
            self.reader = csv.reader(File)
            bHeader = True
            bTable = False
            bUnit = False
            bData = False
            bKeyUnitData = False
            bKeyData = False
            for row_data in self.reader:
                if 0 < len(row_data):
                    if 0 <= row_data[0].find("name="):
                        bHeader = False
                        bTable = True
                        bUnit = False
                        bData = False
                        name = cut_str(row_data[0], "name=")
                        case = ""
                        if 1 < len(row_data) and 0 <= row_data[1].find("case="):
                            case = cut_str(row_data[1], "case=")

                        tab = ""
                        if 1 < len(row_data) and 0 <= row_data[1].find("tab="):
                            tab = cut_str(row_data[1], "tab=")

                        group = ""
                        if 2 < len(row_data) and 0 <= row_data[2].find("group="):
                            group = cut_str(row_data[2], "group=")

                        self.tables.append(Ss7CsvInputTable())
                        table = self.tables[-1]
                        table.set_name(name, case, tab, group)
                        continue

                    elif (
                        2 <= len(row_data)
                        and 0 <= row_data[-1].find("<data>")
                        and 0 <= row_data[-2].find("<unit>")
                    ):
                        bKeyUnitData = True
                        continue

                    elif len(row_data) == 1 and 0 <= row_data[0].find("<unit>"):
                        bUnit = True
                        bData = False
                        continue

                    elif len(row_data) == 1 and 0 <= row_data[0].find("<data>"):
                        bData = True
                        bUnit = False
                        continue

                    elif 0 <= row_data[-1].find("<data>"):
                        bKeyData = True
                        continue

                    if bTable:
                        if bUnit:
                            table.set_csv_unit(row_data)
                        elif bData:
                            table.set_csv_data(row_data)
                        elif bKeyData:
                            table.set_csv_key_data(row_data)
                        elif bKeyUnitData:
                            table.set_csv_key_unit_data(row_data)
                        else:
                            table.set_csv_field(row_data)

                    if bHeader:
                        self.headers.append(row_data)

                else:
                    bTable = False
                    bUnit = False
                    bData = False
                    bKeyUnitData = False
                    bKeyData = False

    def write(self, file_path: str) -> None:
        """
        ファイル出力
        Args:
            file_path (str): CSVファイルパス
        """
        data = []

        # ヘッダー
        for header in self.headers:
            data.append(header)
        data.append([])  # 間に空行を入れる

        # テーブル
        for table in self.tables:
            table.output(data)
            data.append([])  # 間に空行を入れる

        with open(file_path, mode="w", newline="", encoding="cp932") as File:
            writer = csv.writer(File)
            writer.writerows(data)

    def get_all_tables(self) -> [Ss7CsvInputTable]:
        """
        全テーブルの取得
        Returns:
            [Ss7CsvInputTable]: テーブルリスト
        """
        return self.tables

    def search_table(self, table_name: str) -> Optional[Ss7CsvInputTable]:
        """
        テーブルの検索
        Args:
            table_name (str): テーブル名（name=）
        Returns:
            Optional[Ss7CsvInputTable]: テーブル（見つからなかった場合はNone）
        """
        for table in self.tables:
            if table.name == table_name:
                return table
        return None

    def search_table_case(
        self, table_name: str, case_or_tab_name: str, group_name: str = ""
    ) -> Optional[Ss7CsvInputTable]:
        """
        テーブルの検索
        Args:
            table_name (str): テーブル名（name=）
            case_or_tab_name (str): ケース名（case=）もしくはタブ名（tab=）
            group_name (str): グループ名（group=）（省略可能）
        Returns:
            Optional[Ss7CsvInputTable]: テーブル（見つからなかった場合はNone）
        """
        for table in self.tables:
            if table.name == table_name:
                if table.case == case_or_tab_name:
                    if table.group == group_name:
                        return table
                elif table.tab == case_or_tab_name:
                    if table.group == group_name:
                        return table
        return None
