#!/usr/bin/env python
import logging
import os
import sqlite3
import json
import operator
import sys

from functools import cached_property
from typing import List
from typing import Optional
from dataclasses import dataclass
import nsys_cpu_stats.trace_utils as tu

from nsys_cpu_stats.trace_loader import TraceLoaderInterface, TraceLoaderSupport, TraceLoaderRegions, TraceLoaderGPUMetrics, TraceLoaderEvents
from nsys_cpu_stats.trace_utils import FrameDurations, Frames, CallStack, CallStackFrame, CallStackType, TimeSlice, GPUMetric

logger = logging.getLogger(__name__)

@dataclass
class EventMarker:
    """One event record with its time range, ids and owning thread/process.

    Every field defaults to -1, which this module uses as an "unset" value.
    Field order is significant: it defines the positional constructor order.
    """

    # Identifier of the event record itself.
    id: int = -1
    # Start and end timestamps of the event.
    start: int = -1
    end: int = -1
    # Correlation id, presumably linking the event to related records - TODO confirm.
    correlation_id: int = -1
    # Id of the event's name (likely a StringIds id - verify against callers).
    name_id: int = -1
    # Thread and process the event belongs to.
    tid: int = -1
    pid: int = -1


class NSysRepLoader(TraceLoaderInterface):
    ####################################################
    #
    # Init/Cleanup code
    #
    ####################################################
    def __init__(self):
        """Create a loader with no report attached; call init_database() to open one."""
        super().__init__()
        # No SQLite report is attached and no table list is known yet.
        self.__connection: Optional[sqlite3.Connection] = None
        self.table_name_list: Optional[list] = None
        # Gates optional debug logging (see __are_dxgkrnl_profile_ranges_available).
        self.quiet = True
        # DTSP-17586 - the export schema version lets this loader read SQL reports
        # from NSys 2024.5.1 as well as from older versions. 0.0 means "not detected".
        self.nsys_schema_version_major, self.nsys_schema_version_minor = 0, 0

    def init_database(self, file_in: str) -> bool:
        """Initialise the NSys Loader with an SQLite database"""
        # Open the connection
        if not os.path.exists(file_in):
            logger.error(f"File '{file_in}' does not exist")
            return False

        self.close_database()
        self.__connection = sqlite3.connect(file_in)
        self.__connection.text_factory = lambda b: b.decode(errors='ignore')  # Ignore utf errors : OperationalError: Could not decode to UTF-8 column 'value' with text '
        self.__connection.row_factory = sqlite3.Row
        self.__init_common()
        return True

    def close_database(self):
        """Close the current database"""
        if self.__connection is not None:
            self.__connection.close()
            self.__connection = None

    ####################################################
    #
    # Private helpers
    #
    ####################################################
    def __init_common(self):
        self.table_name_list = self.__init_table_name_list()
        self.__detect_sqlite_schema_version()
        self.string_ids = self.__init_string_table()

        # Do we want verbose printing
        self.verbose = False
        self.nvtx_domain_dict = self.__get_nvtx_domain_dictionary()

    def __get_connection(self):
        return self.__connection

    def __get_cursor(self):
        return self.__connection.cursor()

    def __init_string_table(self):
        cursor = self.__connection.cursor()
        cursor.execute('select id,value from StringIds')
        string_ids = cursor.fetchall()
        return string_ids

    def __init_table_name_list(self):
        """Capture a list of table names from the sql db"""
        table_name_list = []
        for rows in self.__connection.execute("SELECT name FROM sqlite_master WHERE type='table';"):
            table_name_list.append(rows['name'])

        # urgh - add task names as a new table entry
        if 'ETW_EVENTS' in table_name_list:
            cursor = self.__connection.execute('SELECT * FROM ETW_EVENTS')
            table_column_names = list(map(lambda x: x[0], cursor.description))
            if 'taskName' in table_column_names:
                table_name_list.append('ETW_EVENT_TASK_NAMES')

        if 'ETW_EVENTS_DEPRECATED_TABLE' in table_name_list:
            cursor = self.__connection.execute('SELECT * FROM ETW_EVENTS_DEPRECATED_TABLE')
            table_column_names = list(map(lambda x: x[0], cursor.description))
            if 'taskName' in table_column_names:
                table_name_list.append('ETW_EVENT_TASK_NAMES')

        return table_name_list

    def __detect_sqlite_schema_version(self):
        if not self.table_name_list or 'META_DATA_EXPORT' not in self.table_name_list:
            logger.debug("Could not determine Nsight Systems version used to export SQL database")
            return
        row = self.__connection.execute("SELECT value FROM META_DATA_EXPORT WHERE name='EXPORT_SCHEMA_VERSION_MAJOR' LIMIT 1;").fetchone()
        if row:
            self.nsys_schema_version_major = int(row['value'])
        row = self.__connection.execute("SELECT value FROM META_DATA_EXPORT WHERE name='EXPORT_SCHEMA_VERSION_MINOR' LIMIT 1;").fetchone()
        if row:
            self.nsys_schema_version_minor = int(row['value'])
        if self.__does_etw_table_contain_event_meta:
            logger.debug("SQL database was exported by a legacy version of Nsight Systems [<= 2024.4]")

    @cached_property
    def __does_etw_table_contain_event_meta(self):
        # The SQLite schema update in Nsight Systems that changed ETW_EVENTS to use
        # generic events was in version 3.13:
        old_major = self.nsys_schema_version_major < 3
        old_minor = (self.nsys_schema_version_major == 3) and (self.nsys_schema_version_minor < 13)
        return old_major or old_minor

    def __get_etw_events_sql_table_for_query(self):
        if self.__does_etw_table_contain_event_meta:
            return 'ETW_EVENTS'
        return '''ETW_EVENTS JOIN GENERIC_EVENT_TYPES on ETW_EVENTS.typeId=GENERIC_EVENT_TYPES.typeId'''

    def __etw_task_id_column_name(self):
        return 'taskId' if self.__does_etw_table_contain_event_meta else 'etwTaskId'

    def __etw_provider_id_column_name(self):
        return 'providerId' if self.__does_etw_table_contain_event_meta else 'etwProviderId'

    def __etw_event_id_column_name(self):
        return 'eventId' if self.__does_etw_table_contain_event_meta else 'etwEventId'

    def __get_pid_tid_from_row(self, row):
        if self.__does_etw_table_contain_event_meta:
            return (row['processId'] if ('processId' in row.keys()) else None,
                    row['threadId'] if ('threadId' in row.keys()) else None)
        return tu.convert_global_tid(row['globalTid'])

    def __supports_table(self, name):
        return name in self.table_name_list

    def __does_sql_column_name_exist_in_table(self, column_name, table_name):
        cursor = self.__connection.cursor()
        cursor.execute(f'PRAGMA table_info({table_name})')
        columns = cursor.fetchall()
        return any(column[1] == column_name for column in columns)

    ####################################################
    # Are CPU frametimes available
    ####################################################
    def __are_cpu_frame_times_available(self,
                                        start_time_ns: Optional[int] = None,
                                        end_time_ns: Optional[int] = None,
                                        target_pid: Optional[int] = None) -> bool:
        """Determine if any Present() calls/events are available."""
        # Todo - instead of 'Select', can we just count the number of items via a query?

        # Several different ways to get the CPU frame times, all based on capturing present information
        # 1. ETW event task names - this is old and deprecated, but we keep it in case we see any old traces
        if self.__supports_table('ETW_EVENT_TASK_NAMES'):
            # The old deprecated way
            if self.__supports_table('ETW_EVENTS_DEPRECATED_TABLE'):
                rows = self.__get_connection().execute('select * from ETW_EVENTS_DEPRECATED_TABLE where taskName=?', (self.get_string_id('Present'),))
            else:
                rows = self.__get_connection().execute('select * from ETW_EVENTS where taskName=?', (self.get_string_id('Present'),))

            # return True if there is a row
            for row in rows:
                return True

        # 2. Standard ETW events
        if self.__supports_table('ETW_TASKS'):

            # This is the way
            presentTaskID = -1
            for row in self.__get_connection().execute('select * from ETW_TASKS where taskNameId=?', (self.get_string_id('Present'),)):
                presentTaskID = row['taskId']

            if presentTaskID < 0:
                return False

            # Get the list of frames and frame start times
            rows = None
            if target_pid:
                pid_column_name = 'processId' if self.__does_etw_table_contain_event_meta else 'globalTid'
                query = 'select * from ' + self.__get_etw_events_sql_table_for_query()
                query += ' where ' + self.__etw_task_id_column_name() + '=? and '
                query += pid_column_name + '=?'
                rows = self.__get_connection().execute(query, (presentTaskID, target_pid,))
            else:
                query = 'select * from ' + self.__get_etw_events_sql_table_for_query()
                query += ' where ' + self.__etw_task_id_column_name() + '=?'
                rows = self.__get_connection().execute(query, (presentTaskID,))
            for row in rows:
                if start_time_ns and row['timestamp'] < start_time_ns:
                    continue
                if end_time_ns and row['timestamp'] > end_time_ns:
                    continue
                return True

        # 3. Look in the WDDM queue packets for the present packet
        if self.__supports_table('WDDM_QUEUE_PACKET_START_EVENTS'):
            rows = self.__get_connection().execute('SELECT start, end, globalTid, present FROM WDDM_QUEUE_PACKET_START_EVENTS WHERE present=1 ORDER BY start ASC')
            for row in rows:
                if start_time_ns and row['end'] < start_time_ns:
                    continue
                if end_time_ns and row['start'] > end_time_ns:
                    continue
                if target_pid and tu.get_pid(row['globalTid']) != target_pid:
                    continue
                return True

        return False

    ####################################################
    # Are DxgKrnl ETW events (used for DXGKRNL profile ranges) available
    ####################################################
    def __are_dxgkrnl_profile_ranges_available(self) -> bool:
        """Determine if the DxgKrnl ETW events are available"""
        if not self.__supports_table('ETW_EVENTS') or not self.__supports_table('ETW_TASKS') or not self.__supports_table('ETW_PROVIDERS'):
            if not self.quiet:
                logger.debug("ETW_EVENTS/ETW_TASKS/ETW_PROVIDERS not supported")
            return False

        for row in self.__get_connection().execute('select * from ETW_PROVIDERS where providerNameId=?', (self.get_string_id('Microsoft-Windows-DxgKrnl'),)):
            return True
        return False

    ####################################################
    # Are GPU metrics supported
    ####################################################
    def __are_gpu_metrics_available(self) -> bool:
        """Are any GPU metrics available in this report"""
        if self.nsys_schema_version_major < 3 or (self.nsys_schema_version_major == 3 and self.nsys_schema_version_minor <= 14):
            if not (self.__supports_table('GENERIC_EVENT_SOURCES') and self.__supports_table('GENERIC_EVENT_TYPES') and self.__supports_table('GENERIC_EVENTS')):
                logger.warning("GENERIC_EVENT_SOURCES not supported.")
                return False

            # Look for the metric we care about in the SOURCES
            for row in self.__get_connection().execute('''SELECT sourceId, data FROM GENERIC_EVENT_SOURCES'''):
                data_dict = json.loads(row['data'])
                if data_dict["Name"] == "GpuMetrics":
                    return True
        else:
            # Newer schemas use a different mechanism
            if not (self.__supports_table('GPU_METRICS') and self.__supports_table('TARGET_INFO_GPU_METRICS')):
                logger.warning("GPU_METRICS not supported.")
                return False
            return True
        return False

    ####################################################
    # Determine GPU metric names
    ####################################################
    def __determine_gpu_metric_support(self):
        """Interrogate the GPU metrics to determine what is supported and what the metric names are.

        For every recognised metric, the name used by this report is stored in
        ``self.gpu_metric_names`` (the last occurrence wins). The metric is added
        to ``self.gpu_metrics_supported`` at most once, even when several GPUs or
        event types report the same metric (it used to be appended once per
        occurrence).
        """
        # Accepted spellings of each metric, plain and with a bracketed unit.
        metric_aliases = (
            (TraceLoaderGPUMetrics.GPU_UTILISATION,
             ("GR Active", "GR Active [Throughput %]")),
            (TraceLoaderGPUMetrics.PCIE_BAR1_WRITES,
             ("PCIe Write Requests to BAR1", "PCIe Write Requests to BAR1 [Requests]")),
            (TraceLoaderGPUMetrics.PCIE_BAR1_READS,
             ("PCIe Read Requests to BAR1", "PCIe Read Requests to BAR1 [Requests]")),
        )

        def record(metric_name):
            # Map a report-specific metric name onto its TraceLoaderGPUMetrics value, if recognised.
            for metric, aliases in metric_aliases:
                if metric_name in aliases:
                    self.gpu_metric_names[metric] = metric_name
                    if metric not in self.gpu_metrics_supported:
                        self.gpu_metrics_supported.append(metric)
                    return

        if self.nsys_schema_version_major < 3 or (self.nsys_schema_version_major == 3 and self.nsys_schema_version_minor <= 14):
            event_source = -1
            self.gpu_metric_names = {}

            if not (self.__supports_table('GENERIC_EVENT_SOURCES') and self.__supports_table('GENERIC_EVENT_TYPES') and self.__supports_table('GENERIC_EVENTS')):
                logger.warning("GENERIC_EVENT_SOURCES not supported.")
                return

            for row in self.__get_connection().execute('''SELECT sourceId, data FROM GENERIC_EVENT_SOURCES'''):
                if json.loads(row['data'])["Name"] == "GpuMetrics":
                    event_source = row['sourceId']
                    break

            if event_source == -1:
                logger.warning("metrics not supported.")
                return

            # Double check it in the events: each type's JSON lists its metric fields.
            for row in self.__get_connection().execute('SELECT sourceId, typeId, data FROM GENERIC_EVENT_TYPES where sourceId=?', (event_source,)):
                for field in json.loads(row['data'])["Fields"]:
                    record(field["Name"])
        else:
            if not (self.__supports_table('GPU_METRICS') and self.__supports_table('TARGET_INFO_GPU_METRICS')):
                logger.warning("GPU_METRICS not supported.")
                return

            for row in self.__get_connection().execute('SELECT * FROM TARGET_INFO_GPU_METRICS'):
                record(row["metricName"])

    ####################################################
    # Helper to get the nvtx domain names
    ####################################################
    def __get_nvtx_domain_dictionary(self) -> Optional[dict]:
        if not self.__supports_table("NVTX_EVENTS"):
            logger.warning("NVTX_EVENTS sql table not supported")
            return None

        domain_dict = {}
        # Get the domain creation events (eventType = 75) and use that to build a dictionary of domains
        query = 'SELECT * FROM NVTX_EVENTS WHERE eventType == 75 ORDER BY start ASC'
        for row in self.__get_connection().execute(query):
            pid, tid = tu.convert_global_tid(row['globalTid'])
            name_id = row['textId']
            if name_id:
                name = self.get_string(name_id)
            else:
                name = str(row['text'])

            domain_dict[row['domainId']] = name
        return domain_dict

    ####################################################
    #
    # Public Interface
    #
    ####################################################
    ####################################################
    # Determine what features this loader supports
    ####################################################
    def determine_support(self):
        """Determine what features this loader supports.

        Resets every support list (and the GPU metric name map), then probes the
        open database in a fixed order. Matching TraceLoaderSupport /
        TraceLoaderRegions / TraceLoaderEvents values are appended; most missing
        features are reported with ``logger.warning``.
        """
        has_table = self.__supports_table

        def enable(found, target, features, missing_message=None):
            # Extend *target* with *features* when *found*; otherwise optionally warn.
            if found:
                target.extend(features)
            elif missing_message is not None:
                logger.warning(missing_message)

        # Clear all support
        self.supported = []
        self.regions_supported = []
        self.derived_regions_supported = []
        self.pipeline_regions_supported = []
        self.gpu_metrics_supported = []
        self.gpu_metric_names = {}
        self.events_supported = []

        enable(has_table('SCHED_EVENTS'), self.supported,
               [TraceLoaderSupport.TIMESLICES],
               "Timeslice information not found.")
        enable(has_table('SAMPLING_CALLCHAINS'), self.supported,
               [TraceLoaderSupport.CALLSTACKS, TraceLoaderSupport.CSWITCH_CALLSTACK_BACK],
               "Sampled callstack information not found.")
        enable(has_table('TARGET_INFO_SYSTEM_ENV'), self.supported,
               [TraceLoaderSupport.CORE_COUNT],
               "Core count information not found.")
        enable(self.__supports_cpu_config(), self.supported,
               [TraceLoaderSupport.CPU_CONFIG],
               "Core config information not found.")
        enable(self.__are_cpu_frame_times_available(), self.regions_supported,
               [TraceLoaderRegions.CPU_FRAMETIMES],
               "CPU Frametime information not found.")

        if not self.__are_gpu_metrics_available():
            logger.warning("GPU Frametimes/Metrics not found.")
        else:
            self.supported.append(TraceLoaderSupport.GPU_METRICS)
            self.__determine_gpu_metric_support()
            # GPU frametimes are derived from the GPU utilisation metric.
            if TraceLoaderGPUMetrics.GPU_UTILISATION in self.gpu_metrics_supported:
                self.derived_regions_supported.append(TraceLoaderRegions.GPU_FRAMETIMES)

        has_nvtx = has_table("NVTX_EVENTS")
        enable(has_nvtx, self.events_supported,
               [TraceLoaderEvents.NVTX_MARKERS],
               "NVTX/Reflex information not found.")
        enable(has_table('ANALYSIS_DETAILS') or has_table('META_DATA_CAPTURE'), self.supported,
               [TraceLoaderSupport.ANALYSIS_DURATION])
        enable(has_table("DX12_API"), self.events_supported,
               [TraceLoaderEvents.DX12_API_CALLS],
               "DX12 API information not found.")

        has_cuda_runtime = has_table("CUPTI_ACTIVITY_KIND_RUNTIME")
        if has_cuda_runtime:
            self.events_supported.append(TraceLoaderEvents.CUDA_API_CALLS)
            self.regions_supported.append(TraceLoaderRegions.CUDNN_KERNEL_LAUNCHES)
        else:
            logger.warning("Cuda API information not found.")

        has_cuda_kernels = has_table("CUPTI_ACTIVITY_KIND_KERNEL")
        if has_cuda_kernels:
            self.events_supported.append(TraceLoaderEvents.CUDA_GPU_KERNELS)
            self.regions_supported.append(TraceLoaderRegions.CUDNN_GPU_KERNELS)
        else:
            logger.warning("Cuda GPU kernel information not found.")

        enable(has_nvtx and has_cuda_runtime and has_cuda_kernels, self.events_supported,
               [TraceLoaderEvents.NVTX_GPU_MARKERS])
        enable(has_table("MPI_OTHER_EVENTS"), self.events_supported,
               [TraceLoaderEvents.MPI_MARKERS],
               "MPI information not found.")
        enable(has_table('D3D12_PIX_DEBUG_API'), self.events_supported,
               [TraceLoaderEvents.PIX_MARKERS],
               "PIX markers information not found.")
        enable(has_table('DX12_WORKLOAD'), self.events_supported,
               [TraceLoaderEvents.DX12_GPU_WORKLOAD])
        enable(self.__are_dxgkrnl_profile_ranges_available(), self.events_supported,
               [TraceLoaderEvents.DXGKRNL_PROFILE_RANGE])
        enable(has_table('ETW_EVENTS') and has_table('ETW_TASKS'), self.events_supported,
               [TraceLoaderEvents.ETW_EVENTS],
               "ETW event information not found.")

    ####################################################
    # Determine if this is a graphics workload
    ####################################################
    def is_graphics_workload(self,
                             start_time_ns: int,
                             end_time_ns: int,
                             target_pid: int):
        """Return True if the window holds DX12_WORKLOAD or DX12_API events, or Present() events, for the process."""
        # Look for DX12 workload; stop at the first source that reports events.
        for table in ('DX12_WORKLOAD', 'DX12_API'):
            if self.__are_events_available_in_sql_table(table, start_time_ns, end_time_ns, target_pid):
                return True
        return bool(self.__are_cpu_frame_times_available(start_time_ns, end_time_ns, target_pid))

    ####################################################
    # Determine if this is a compute workload
    ####################################################
    def is_compute_workload(self,
                            start_time_ns: int,
                            end_time_ns: int,
                            target_pid: int):
        """Return True if the window holds CUDA kernel or CUDA runtime events for the process."""
        # Look for CUDA workload: kernels first, then runtime API calls (short-circuits).
        return any(self.__are_events_available_in_sql_table(table, start_time_ns, end_time_ns, target_pid)
                   for table in ('CUPTI_ACTIVITY_KIND_KERNEL', 'CUPTI_ACTIVITY_KIND_RUNTIME'))

    ####################################################
    # Initialise the thread name dictionary
    ####################################################
    def init_thread_name_dict(self) -> dict:
        """Initialise the thread name dictionary stored in super()."""
        cursor = self.__connection.cursor()
        thread_name_dict = {}
        if not self.__supports_table('ThreadNames'):
            logger.error("ThreadNames table not found, can't decode thread names")
            return thread_name_dict
        for row in self.__connection.execute('select nameId,globalTid,priority from ThreadNames'):
            cursor.execute('select id,value from StringIds where id=?', (str(row[0]),))
            thread_name = cursor.fetchone()
            thread_name_dict[row[1]] = thread_name[1]
        return thread_name_dict

    ####################################################
    # Initialise the process name dictionary
    ####################################################
    def init_process_name_dict(self) -> dict:
        """Initialise the process name dictionary stored in super()."""
        process_name_dict = {}
        if not self.__supports_table('PROCESSES'):
            logger.error("PROCESSES table not found, can't decode process names")
            return process_name_dict
        for row in self.__connection.execute('select globalPid, pid, name from PROCESSES'):
            process_name_dict[row['pid']] = row['name']
        return process_name_dict

    ####################################################
    # String table lookups: resolve StringIds ids to text and text back to ids
    ####################################################
    def get_string(self, sid) -> str:
        if sid:
            if self.string_ids[sid]['id'] == sid:
                return self.string_ids[sid]['value']
            for s in filter(lambda s: s['id'] == sid, self.string_ids):
                return s['value']
        return "Unknown"

    def get_module_string(self, sid) -> str:
        return self.get_string(sid).split("\\")[-1].split("/")[-1]

    def get_string_id(self, string: str):
        for s in filter(lambda s: s['value'] == string, self.string_ids):
            return s['id']
        if self.verbose:
            logger.warning(f"Failed to match string: {string}")
        return -1

    ####################################################
    #
    # Private Interface - used by super
    #
    ####################################################
    ####################################################
    # Get timeslices
    ####################################################
    def _get_timeslices(self,
                        start_time_ns: Optional[float] = None,
                        end_time_ns: Optional[float] = None,
                        target_pid: Optional[int] = None,
                        quiet: Optional[bool] = False) -> List[TimeSlice]:
        """Convert SCHED_EVENTS context-switch records into per-thread TimeSlices.

        :param start_time_ns: optional exclusive lower bound on the event start.
        :param end_time_ns: optional exclusive upper bound on the event start.
            Each bound is applied on its own. Previously a lone end bound
            returned nothing, because the missing start bound was bound as the
            text 'None'.
        :param target_pid: when given, only threads of this process are kept.
        :param quiet: forwarded to the progress bar.
        :return: TimeSlice(start, end, cpu, globalTid) entries.
        :raises RuntimeError: if the database has no SCHED_EVENTS table.
        """
        if not self.__supports_table('SCHED_EVENTS'):
            raise RuntimeError('SQL DB does not contain SCHED_EVENTS which is required')

        timeslice_list = []

        cursor = self.__get_cursor()

        #
        # The SCHED_EVENTS track context switches of threads.
        # isSchedIn = True when swapped in, otherwise it is swapped out.
        #
        # This code will convert the timestamp events into timeslice ranges with a start/end for a given thread.
        # It does this by keeping a dictionary of timestamps per thread and using this to create the timeslice ranges
        #

        thread_dict = {}

        # Build the WHERE clause from whichever bounds were supplied; the bounds
        # are passed as strings, exactly as before.
        conditions = []
        params = []
        if start_time_ns is not None:
            conditions.append('start > ?')
            params.append(str(start_time_ns))
        if end_time_ns is not None:
            conditions.append('start < ?')
            params.append(str(end_time_ns))
        query = 'select * from SCHED_EVENTS'
        if conditions:
            query += ' WHERE ' + ' and '.join(conditions)
        cursor.execute(query + ' order by start', params)

        # NSys sometimes drops sched events.
        # If a swap out is immediately followed by a swap in, then the out event will be dropped.
        # In fact, it will drop events where the thread doesn't significantly change state. It also ignores
        # CPU execution cores, so can have threads starting on 1 CPU and ending on another!
        cpu_mismatch = 0
        missed_sched_out = 0
        missed_sched_in = 0
        results = cursor.fetchall()  # already a list; no copy needed
        for row in tu.progressbar(results, "Processing Timeslices: ", 40, quiet):
            time = row['start']
            cpu = row['cpu']
            is_in = row['isSchedIn']
            gtid = row['globalTid']

            if target_pid is not None:
                pid, tid = tu.convert_global_tid(gtid)
                if pid != target_pid:
                    continue

            # If the gtid has not been seen, just store it and continue
            if gtid not in thread_dict:
                thread_dict[gtid] = (time, cpu, is_in)
                continue

            # retrieve the previous timeslice
            last_time, last_cpu, last_is_in = thread_dict[gtid]

            # Store the current timeslice
            thread_dict[gtid] = (time, cpu, is_in)

            # Have we missed a sched out event?
            if is_in and last_is_in:
                missed_sched_out += 1
                # Add a timeslice event assuming the previous timeslice ends now, then continue assuming a new timeslice
                timeslice_list.append(TimeSlice(last_time, time, last_cpu, gtid))

            # Have we missed a sched in event
            if not is_in and not last_is_in:
                missed_sched_in += 1
                # Skip this event
                continue

            if not is_in:
                # NSys doesn't check for CPU when dropping events (is_out immediately followed by is_in will get dropped)
                # So it is hard to track core affinity
                # We need to ignore CPU mismatches - which makes me sad :(
                if last_cpu != cpu:
                    cpu_mismatch += 1
                    #  continue  # Maybe one day, in a galaxy far far away, we can re-instate this check!

                # end of timeslice, append it
                timeslice_list.append(TimeSlice(last_time, time, cpu, gtid))

        if cpu_mismatch:
            logger.warning(f"WARNING: {cpu_mismatch} invalid CPU mismatches with timeslice start/end. Timeslices may be incorrect")
        if missed_sched_in:
            logger.warning(f"WARNING: {missed_sched_in} missed cswitch in events. Timeslices may be incorrect")
        if missed_sched_out:
            logger.warning(f"WARNING: {missed_sched_out} missed cswitch out events. Timeslices may be incorrect")

        return timeslice_list

    ####################################################
    # Get region durations
    ####################################################
    def _get_region_durations(self,
                              region_type: TraceLoaderRegions,
                              start_time_ns: Optional[float] = None,
                              end_time_ns: Optional[float] = None,
                              target_pid: Optional[int] = None) -> (float, list[FrameDurations]):
        """Return (average duration, frame list) for a supported region type.

        A region type that is supported but not handled below yields (0, None).
        """
        assert region_type in self.regions_supported

        if region_type == TraceLoaderRegions.CPU_FRAMETIMES:
            avg_duration, frame_list = self.__get_cpu_frame_times(start_time_ns, end_time_ns, target_pid)
        elif region_type == TraceLoaderRegions.CUDNN_KERNEL_LAUNCHES:
            avg_duration, frame_list = self.__get_cuda_api_frame_times("cuLaunchKernel", "cudnn_generated", start_time_ns, end_time_ns, target_pid)
        elif region_type == TraceLoaderRegions.CUDNN_GPU_KERNELS:
            avg_duration, frame_list = self.__get_cuda_gpu_frame_times("cudnn_generated", start_time_ns, end_time_ns, target_pid)
        else:
            avg_duration, frame_list = 0, None
        return avg_duration, frame_list

    def _get_derived_region_durations(self,
                                      region_type: TraceLoaderRegions,
                                      base_durations: list[FrameDurations],
                                      start_time_ns: Optional[float] = None,
                                      end_time_ns: Optional[float] = None,
                                      target_pid: Optional[int] = None,
                                      ) -> list[FrameDurations]:
        """Compute a derived region list from *base_durations*; None for unhandled types.

        target_pid is not forwarded to __get_gpu_frame_times.
        """
        assert region_type in self.derived_regions_supported
        if region_type != TraceLoaderRegions.GPU_FRAMETIMES:
            return None
        return self.__get_gpu_frame_times(base_durations, start_time_ns=start_time_ns, end_time_ns=end_time_ns)

    ####################################################
    # Get CPU frametimes
    ####################################################
    def __get_cpu_frame_times(self,
                              start: Optional[float] = None,
                              end: Optional[float] = None,
                              target_pid: Optional[int] = None) -> (float, list[FrameDurations]):
        """Compute CPU frame times from Present() activity recorded in the report.

        Three sources are tried in order; the first one that yields any frames wins:
          1. ETW events whose task name is 'Present' (old, deprecated report layout).
          2. ETW 'Present' task events from the preferred provider/opcode.
          3. WDDM queue packets flagged as present packets.

        Args:
            start: Optional lower time bound; presents before it are ignored.
            end: Optional upper time bound; presents after it are ignored.
            target_pid: If not None, only presents from this process are used.

        Returns:
            (average frame duration divided by 1e6, list of FrameDurations), or (0, None)
            when no measurable frames were found.
        """
        frame_list = []

        # Get the provider IDs from ETW
        preferred_provider = None
        preferred_opcode = None
        if self.__supports_table('ETW_PROVIDERS'):
            provider_id_dict = {}
            for row in self.__get_connection().execute('select * from ETW_PROVIDERS'):
                provider_id_dict[self.get_string(row["providerNameId"])] = row['providerId']

            # Choose the priority provider from the ordered list
            # DxgKrnl aligns with nsys and ignores DXGI's test present
            provider_priority_list = [
                # Provider, Opcode
                ("Microsoft-Windows-DxgKrnl", 0),
                ("Microsoft-Windows-DXGI", 1),
            ]
            for provider, opcode in provider_priority_list:
                if provider in provider_id_dict:
                    preferred_provider = provider_id_dict[provider]
                    preferred_opcode = opcode  # START opcode - valid for DXGI and DxgKrnl
                    break

            if self.verbose:
                logger.debug(f"preferred ETW provider for Present() {preferred_provider}")
        else:
            logger.warning("Doesn't support ETW_PROVIDERS.")

        # 1. ETW event task names - this is old and deprecated, but we keep it in case we see any old traces
        if self.__supports_table('ETW_EVENT_TASK_NAMES'):
            if self.verbose:
                logger.debug("ETW_EVENT_TASK_NAMES")
            # The old deprecated way
            if self.__supports_table('ETW_EVENTS_DEPRECATED_TABLE'):
                if self.verbose:
                    logger.debug("ETW_EVENTS_DEPRECATED_TABLE")
                rows = self.__get_connection().execute('select * from ETW_EVENTS_DEPRECATED_TABLE where taskName=? ORDER BY start ASC', (self.get_string_id('Present'),))
            else:
                rows = self.__get_connection().execute('select * from ETW_EVENTS where taskName=? ORDER BY start ASC', (self.get_string_id('Present'),))
                if self.verbose:
                    logger.debug("ETW_EVENTS")

            # Get the list of frames and frame start times
            for row in rows:
                pid, _ = tu.convert_global_tid(row['globalTid'])
                if (target_pid is not None) and (target_pid != pid):
                    continue

                # Trim
                if start is not None and row['start'] < start:
                    continue
                if end is not None and row['end'] > end:
                    continue

                frame_list.append(Frames(row['start'], row['end']))

        # 2. Standard ETW events
        if not frame_list and self.__supports_table('ETW_TASKS'):
            if self.verbose:
                logger.debug("ETW_TASKS")

            # Need to get the right provider and use a consistent opcode (ie. mindful of start/end pairs)
            present_task_id = -1
            for row in self.__get_connection().execute('select * from ETW_TASKS where taskNameId=?', (self.get_string_id('Present'),)):
                # Compare against None: a provider ID of 0 is a valid ID, not "unset".
                if preferred_provider is None:
                    preferred_provider = row['providerId']

                if row['providerId'] != preferred_provider:
                    continue

                present_task_id = row['taskId']

            if present_task_id < 0:
                # No Present task for this provider: fall through to the WDDM source below
                # instead of giving up on frame times entirely.
                if self.verbose:
                    logger.debug("No ETW Present task found")
            else:
                # Get the list of frames and frame start times
                last_present = start
                query = 'select * from ' + self.__get_etw_events_sql_table_for_query() + ' where '
                query += self.__etw_task_id_column_name() + '=? AND '
                query += self.__etw_provider_id_column_name() + '=? ORDER BY timestamp ASC'
                for row in self.__get_connection().execute(query, (present_task_id, preferred_provider)):
                    pid, _ = self.__get_pid_tid_from_row(row)
                    if (target_pid is not None) and (target_pid != pid):
                        continue

                    # If we don't have a preferred opcode, lock onto the first one seen.
                    # Compare against None: DxgKrnl's preferred opcode is 0, which is falsy.
                    if preferred_opcode is None:
                        preferred_opcode = row['opcode']

                    # DXGI events have a START/END opcode - we only want the start
                    if row['opcode'] != preferred_opcode:
                        continue

                    # Ignore test presents
                    data_dict = json.loads(row['data'])
                    if 'Flags' in data_dict and 'DXGI_PRESENT_TEST' in data_dict['Flags']:  # Use 'str in flags' instead of == because have seen extra spaces at the end of the flag string
                        continue

                    present_time = row['timestamp']

                    # Trim
                    if start is not None and present_time < start:
                        continue
                    if end is not None and present_time > end:
                        continue

                    if last_present is None:
                        # No start bound: the first present only marks where timing begins.
                        last_present = present_time
                    elif present_time < last_present:
                        logger.warning(f"Error in order of presents: {present_time} prev: {last_present}")

                    frame_list.append(Frames(last_present, present_time))
                    last_present = present_time

        # 3. Look in the WDDM queue packets for the present packet
        if not frame_list and self.__supports_table('WDDM_QUEUE_PACKET_START_EVENTS'):
            last_present = start
            for row in self.__get_connection().execute('SELECT start, end, globalTid, present FROM WDDM_QUEUE_PACKET_START_EVENTS WHERE present=1 ORDER BY start ASC'):
                pid, _ = tu.convert_global_tid(row['globalTid'])
                if (target_pid is not None) and (target_pid != pid):
                    continue

                present_time = row['start']
                # Trim
                if start is not None and present_time < start:
                    continue
                if end is not None and present_time > end:
                    continue

                if last_present is None:
                    # No start bound: the first present only marks where timing begins.
                    last_present = present_time
                elif present_time < last_present:
                    logger.warning(f"Error in order of presents: {present_time} prev: {last_present}")

                frame_list.append(Frames(last_present, present_time))
                last_present = present_time

        if not frame_list:
            logger.error("No CPU frametimes detected")
            return 0, None

        if self.verbose:
            logger.debug(f"Framelist length: {len(frame_list)}")

        # A frame's duration is the gap between consecutive frame start times. Zero gaps
        # (the first entry, duplicate timestamps) are skipped.
        # NOTE(review): the interval after the last recorded start is never emitted as a frame.
        frame_durations_list = []
        total_duration = 0
        last = frame_list[0].start
        for frame in frame_list:
            frame_time = frame.start - last
            if frame_time > 0:
                frame_durations_list.append(FrameDurations(last, frame_time))
                total_duration += frame_time
            last = frame.start

        if not frame_durations_list:
            # e.g. a single present: nothing measurable, and averaging would divide by zero.
            logger.error("No CPU frametimes detected (fewer than two distinct presents)")
            return 0, None

        # Divide by 1e6 (ns -> ms, given the callers' *_ns time bounds).
        average_frame_durations = total_duration / len(frame_durations_list) / 1000000
        return average_frame_durations, frame_durations_list

    ####################################################
    # Helper to get GPU frametimes from CPU frametimes
    ####################################################
    def __get_gpu_frame_times(self,
                              base_durations: List[FrameDurations],
                              start_time_ns: float,
                              end_time_ns: float) -> List[FrameDurations]:
        """Scale CPU frame durations by the GPU utilisation measured inside each frame.

        For every CPU frame, the GPU_UTILISATION samples with
        frame.start <= timestamp <= frame.start + frame.duration are averaged. The GPU frame
        time is the CPU frame duration multiplied by that average taken as a percentage
        (value * 0.01). A frame that contains no samples gets a duration of 0.

        Args:
            base_durations: CPU frame durations; assumed to be in ascending start order.
            start_time_ns: Lower bound forwarded to the GPU metric query.
            end_time_ns: Upper bound forwarded to the GPU metric query.

        Returns:
            One FrameDurations per processed CPU frame, or None if there are no CPU frames,
            no utilisation samples, or no frame could be processed. Processing stops at the
            first frame for which no sample at or after its start remains.
        """
        if not base_durations:
            return None

        gpu_metric_list, _ = self._get_gpu_metric_frame_list(TraceLoaderGPUMetrics.GPU_UTILISATION, start_time_ns=start_time_ns, end_time_ns=end_time_ns)
        if not gpu_metric_list:
            return None

        sample_count = len(gpu_metric_list)
        index = 0
        frametime_list_gpu = []

        # Walk frames and samples together; both ascend, so the sample index never rewinds.
        for frame in base_durations:
            frame_end = frame.start + frame.duration

            # Skip samples that precede this frame (bounded, so exhausting the list is safe).
            while index < sample_count and gpu_metric_list[index].timestamp < frame.start:
                index += 1
            if index == sample_count:
                break

            # Accumulate the GR Active samples that fall inside this frame.
            gr_active = 0.0
            count = 0
            while index < sample_count and gpu_metric_list[index].timestamp <= frame_end:
                gr_active += gpu_metric_list[index].value
                count += 1
                index += 1

            if count:
                gr_active = gr_active / count

            frametime_list_gpu.append(FrameDurations(frame.start, frame.duration * (gr_active * 0.01)))

        return frametime_list_gpu or None

    ####################################################
    # Get the average GPU metrics
    ####################################################
    def _get_average_gpu_metrics(self,
                                 metric: TraceLoaderGPUMetrics,
                                 start_time_ns: Optional[float] = None,
                                 end_time_ns: Optional[float] = None) -> float:
        """Return the average value of one GPU metric over an optional time window.

        GPU_UTILISATION is returned divided by 100 (as a fraction); other metrics are
        returned unchanged. Returns 0 when the metric has no name mapping, the report cannot
        provide it, there are no samples, or the average is not positive.
        """
        metric_name = self.get_gpu_metric_name(metric)
        if not metric_name:
            return 0

        val = self.__get_average_gpu_metrics(metric_name, start_time_ns, end_time_ns)
        # The helper signals "no samples" with -1 and may return a non-numeric value
        # (e.g. None) when the metric is unsupported; neither is a real average.
        if not isinstance(val, (int, float)) or val <= 0:
            return 0

        if metric is TraceLoaderGPUMetrics.GPU_UTILISATION:
            return val / 100
        return val

    ####################################################
    # Get GPU metric list as frame durations
    ####################################################
    def _get_gpu_metric_frame_list(self,
                                   metric_type: TraceLoaderGPUMetrics,
                                   min_metric: Optional[float] = None,
                                   max_metric: Optional[float] = None,
                                   min_percent: Optional[float] = None,
                                   max_percent: Optional[float] = None,
                                   start_time_ns: Optional[float] = None,
                                   end_time_ns: Optional[float] = None) -> (List[GPUMetric], List[FrameDurations]):
        """Load one GPU metric and group its samples into runs of in-range values.

        Args:
            metric_type: Metric to load; mapped to its report name via get_gpu_metric_name().
            min_metric / max_metric: Inclusive absolute bounds a sample must satisfy to be
                part of a region (None = unbounded).
            min_percent / max_percent: Bounds as a percentage of the largest sample; when
                given they override min_metric / max_metric.
            start_time_ns / end_time_ns: Optional window rawTimestamp in [start, end); each
                bound is applied independently when not None.

        Returns:
            (samples, regions): every loaded GPUMetric in query order, and one FrameDurations
            (start, duration, min/max/total value) per run of consecutive in-range samples.
            (None, None) when the metric is unknown, unsupported, or has no samples.
            NOTE(review): region detection assumes rows come back in timestamp order (the
            queries have no ORDER BY) — confirm for large reports.
        """
        metric = self.get_gpu_metric_name(metric_type)
        if not metric:
            return None, None

        gpu_metric_list = self.__load_gpu_metric_samples(metric, start_time_ns, end_time_ns)
        if gpu_metric_list is None:
            return None, None
        if not gpu_metric_list:
            logger.warning(f"{metric} data not found in events.")
            return None, None

        values = [sample.value for sample in gpu_metric_list]
        found_min = min(values)
        found_max = max(0, max(values))  # the running maximum has always started from 0

        # Percentage bounds are relative to the largest sample.
        if min_percent is not None:
            min_metric = found_max * (min_percent / 100)
        if max_percent is not None:
            max_metric = found_max * (max_percent / 100)

        duration_list = []
        in_region = False
        region_start = region_end = 0
        region_min, region_max, region_total = found_max, found_min, 0

        for sample in gpu_metric_list:
            val = sample.value
            in_range = ((min_metric is None or val >= min_metric)
                        and (max_metric is None or val <= max_metric))
            if in_range:
                region_total += val
                region_min = min(region_min, val)
                region_max = max(region_max, val)
                if not in_region:
                    in_region = True
                    region_start = sample.timestamp
                region_end = sample.timestamp
                continue

            # The sample left the range: close the open region (if any) and reset the stats.
            if in_region:
                duration_list.append(FrameDurations(start=region_start, duration=region_end - region_start, min_value=region_min, max_value=region_max, total_value=region_total))
            region_min, region_max, region_total = found_max, found_min, 0
            in_region = False

        if in_region:
            duration_list.append(FrameDurations(start=region_start, duration=region_end - region_start, min_value=region_min, max_value=region_max, total_value=region_total))

        return gpu_metric_list, duration_list

    def __load_gpu_metric_samples(self,
                                  metric: str,
                                  start_time_ns: Optional[float],
                                  end_time_ns: Optional[float]) -> Optional[List[GPUMetric]]:
        """Load GPUMetric(timestamp, value) samples of one named metric, values truncated to int.

        Reads GENERIC_EVENTS for schema versions <= 3.14 and GPU_METRICS otherwise. The
        window is rawTimestamp in [start_time_ns, end_time_ns), each bound applied only when
        not None. Returns None (after a warning) if the report cannot provide the metric,
        otherwise the possibly empty sample list.
        """
        time_filter = ''
        time_params = []
        if start_time_ns is not None:
            time_filter += ' and rawTimestamp >= ?'
            time_params.append(start_time_ns)
        if end_time_ns is not None:
            time_filter += ' and rawTimestamp < ?'
            time_params.append(end_time_ns)

        samples = []
        if self.nsys_schema_version_major < 3 or (self.nsys_schema_version_major == 3 and self.nsys_schema_version_minor <= 14):
            if not (self.__supports_table('GENERIC_EVENT_SOURCES') and self.__supports_table('GENERIC_EVENT_TYPES') and self.__supports_table('GENERIC_EVENTS')):
                logger.warning("GENERIC_EVENT_SOURCES not supported.")
                return None

            # Find the GpuMetrics event source
            event_source = -1
            for row in self.__get_connection().execute('''SELECT sourceId, data FROM GENERIC_EVENT_SOURCES'''):
                if json.loads(row['data'])["Name"] == "GpuMetrics":
                    event_source = row['sourceId']
                    break

            if event_source == -1:
                logger.warning(f"{metric} not supported.")
                return None

            # Find the event type that carries this metric (the last matching type wins)
            event_type = -1
            for row in self.__get_connection().execute('SELECT sourceId, typeId, data FROM GENERIC_EVENT_TYPES where sourceId=?', (event_source,)):
                fields = json.loads(row['data'])["Fields"]
                if any(f["Name"] == metric for f in fields):
                    event_type = row['typeId']

            if event_type == -1:
                logger.warning(f"{metric} not supported.")
                return None

            query = 'SELECT rawTimestamp, data FROM GENERIC_EVENTS where typeId=?' + time_filter
            for row in self.__get_connection().execute(query, (event_type, *time_params)):
                data_dict = json.loads(row['data'])
                # There may be bad samples that lack the metric
                if metric in data_dict:
                    samples.append(GPUMetric(int(row['rawTimestamp']), int(data_dict[metric])))
        else:
            if not (self.__supports_table('GPU_METRICS') and self.__supports_table('TARGET_INFO_GPU_METRICS')):
                logger.warning("GPU_METRICS not supported.")
                return None

            query = ('SELECT * FROM GPU_METRICS JOIN TARGET_INFO_GPU_METRICS '
                     'on GPU_METRICS.metricId=TARGET_INFO_GPU_METRICS.metricId where metricName=?' + time_filter)
            # Parameters must be a sequence: (metric,) — a bare string would bind per character.
            for row in self.__get_connection().execute(query, (metric, *time_params)):
                samples.append(GPUMetric(int(row['rawTimestamp']), int(row['value'])))

        return samples

    ####################################################
    # Find average GPU metrics
    ####################################################
    def __get_average_gpu_metrics(self,
                                  metric: str,
                                  start: Optional[float] = None,
                                  end: Optional[float] = None) -> Optional[float]:
        """Average the samples of one named GPU metric.

        Window (each bound applied only when not None):
          - schema <= 3.14 (GENERIC_EVENTS): start <= rawTimestamp <= end
          - newer (GPU_METRICS):             start <= rawTimestamp <  end

        Returns:
            None if the report cannot provide the metric, -1 if there are no samples in the
            window, otherwise the mean of the int-converted sample values.
        """
        metric_list = []

        if self.nsys_schema_version_major < 3 or (self.nsys_schema_version_major == 3 and self.nsys_schema_version_minor <= 14):
            if not (self.__supports_table('GENERIC_EVENT_SOURCES') and self.__supports_table('GENERIC_EVENT_TYPES') and self.__supports_table('GENERIC_EVENTS')):
                logger.warning("GENERIC_EVENT_SOURCES not supported.")
                return None

            # Last GpuMetrics source wins
            event_source = -1
            for row in self.__get_connection().execute('''SELECT sourceId, data FROM GENERIC_EVENT_SOURCES'''):
                if json.loads(row['data'])["Name"] == "GpuMetrics":
                    event_source = row['sourceId']

            if event_source == -1:
                return None

            # Last event type carrying the metric wins
            event_type = -1
            for row in self.__get_connection().execute('SELECT sourceId, typeId, data FROM GENERIC_EVENT_TYPES where sourceId=?', (event_source,)):
                fields = json.loads(row['data'])["Fields"]
                if any(f["Name"] == metric for f in fields):  # e.g. "GR Active"
                    event_type = row['typeId']

            if event_type == -1:
                logger.warning(f"{metric} not supported.")
                return None

            for row in self.__get_connection().execute('SELECT rawTimestamp, data FROM GENERIC_EVENTS where typeId=?', (event_type,)):
                # Trim before parsing so out-of-window rows cost no JSON decoding
                timestamp = row['rawTimestamp']
                if start is not None and timestamp < start:
                    continue
                if end is not None and timestamp > end:
                    continue

                # There may be bad samples, or no samples
                data_dict = json.loads(row['data'])
                if metric in data_dict:
                    metric_list.append(int(data_dict[metric]))
        else:
            if not (self.__supports_table('GPU_METRICS') and self.__supports_table('TARGET_INFO_GPU_METRICS')):
                logger.warning("GPU_METRICS not supported.")
                return None

            # Only bind bounds that exist: comparing against NULL would match no rows.
            query = 'SELECT * FROM GPU_METRICS JOIN TARGET_INFO_GPU_METRICS on GPU_METRICS.metricId=TARGET_INFO_GPU_METRICS.metricId where metricName=?'
            params = [metric]
            if start is not None:
                query += ' and rawTimestamp >= ?'
                params.append(start)
            if end is not None:
                query += ' and rawTimestamp < ?'
                params.append(end)

            for row in self.__get_connection().execute(query, params):
                metric_list.append(int(row['value']))

        if not metric_list:
            return -1
        return sum(metric_list) / len(metric_list)

    ####################################################
    # Find the averages of every GPU metric
    ####################################################
    def _get_all_average_gpu_metrics(self,
                                     start_time_ns: Optional[float] = None,
                                     end_time_ns: Optional[float] = None) -> (dict, int):
        """Average every GPU metric in the report over an optional time window.

        Window (each bound applied only when not None):
          - schema <= 3.14 (GENERIC_EVENTS): start <= rawTimestamp <= end; only the first
            event type of the GpuMetrics source is read, and nested dict values are
            flattened to their inner keys.
          - newer (GPU_METRICS):             start <= rawTimestamp <  end

        Returns:
            (averages, sample_count): averages maps metric name -> mean value. sample_count is
            the number of GENERIC_EVENTS rows used (legacy) or the largest per-metric sample
            count (newer). (None, 0) when the report has no usable GPU metric data.
        """
        average_gpu_metric_dict = {}

        if self.nsys_schema_version_major < 3 or (self.nsys_schema_version_major == 3 and self.nsys_schema_version_minor <= 14):
            if not (self.__supports_table('GENERIC_EVENT_SOURCES') and self.__supports_table('GENERIC_EVENT_TYPES') and self.__supports_table('GENERIC_EVENTS')):
                logger.warning("GENERIC_EVENT_SOURCES not supported.")
                return None, 0

            # Last GpuMetrics source wins
            event_source = -1
            for row in self.__get_connection().execute('''SELECT sourceId, data FROM GENERIC_EVENT_SOURCES'''):
                if json.loads(row['data'])["Name"] == "GpuMetrics":
                    event_source = row['sourceId']

            if event_source == -1:
                return None, 0

            # Only the first event type of the source is used
            event_type = -1
            first_type = self.__get_connection().execute('SELECT sourceId, typeId, data FROM GENERIC_EVENT_TYPES where sourceId=?', (event_source,)).fetchone()
            if first_type is not None:
                event_type = first_type['typeId']

            metric_totals = {}
            sample_count = 0
            for row in self.__get_connection().execute('SELECT rawTimestamp, data FROM GENERIC_EVENTS where typeId=?', (event_type,)):
                # Trim
                if start_time_ns is not None and row['rawTimestamp'] < start_time_ns:
                    continue
                if end_time_ns is not None and row['rawTimestamp'] > end_time_ns:
                    continue

                for name, value in json.loads(row['data']).items():
                    if isinstance(value, dict):
                        for sub_name, sub_value in value.items():
                            metric_totals[sub_name] = metric_totals.get(sub_name, 0) + float(sub_value)
                    else:
                        metric_totals[name] = metric_totals.get(name, 0) + float(value)

                sample_count += 1

            if sample_count:
                for name, total in metric_totals.items():
                    average_gpu_metric_dict[name] = total / sample_count

        else:
            if not (self.__supports_table('GPU_METRICS') and self.__supports_table('TARGET_INFO_GPU_METRICS')):
                logger.warning("GPU_METRICS not supported.")
                return None, 0

            if self.__get_connection().execute('SELECT 1 FROM TARGET_INFO_GPU_METRICS LIMIT 1').fetchone() is None:
                logger.warning("No metrics found.")
                return None, 0

            # Only bind bounds that exist: comparing against NULL would match no rows.
            query = 'SELECT * FROM GPU_METRICS JOIN TARGET_INFO_GPU_METRICS on GPU_METRICS.metricId=TARGET_INFO_GPU_METRICS.metricId'
            conditions = []
            params = []
            if start_time_ns is not None:
                conditions.append('rawTimestamp >= ?')
                params.append(start_time_ns)
            if end_time_ns is not None:
                conditions.append('rawTimestamp < ?')
                params.append(end_time_ns)
            if conditions:
                query += ' where ' + ' and '.join(conditions)

            # Need an average for each metric
            samples_by_metric = {}
            for row in self.__get_connection().execute(query, params):
                samples_by_metric.setdefault(row['metricName'], []).append(row['value'])

            for name, samples in samples_by_metric.items():
                average_gpu_metric_dict[name] = float(sum(samples)) / float(len(samples))

            sample_count = max((len(samples) for samples in samples_by_metric.values()), default=0)

        return average_gpu_metric_dict, sample_count

    ####################################################
    # Get the callstacks
    ####################################################
    def _get_callstacks(self,
                        start_time_ns: float,
                        end_time_ns: float,
                        target_pid: int,
                        target_tid: Optional[int]) -> List[CallStack]:
        """Collect callstacks whose start lies in [start_time_ns, end_time_ns).

        Each SAMPLING_CALLCHAINS row is one frame. Rows are grouped by id into CallStack
        objects and optionally filtered by pid / tid (None disables a filter). A stack whose
        first row has cpucycles == 1 is tagged SAMPLED, otherwise EVENT. Frames are then
        sorted by NSys stackDepth descending and renumbered 0..n-1 in that order.
        """
        query = '''SELECT * FROM SAMPLING_CALLCHAINS
        LEFT JOIN COMPOSITE_EVENTS ON SAMPLING_CALLCHAINS.id == COMPOSITE_EVENTS.id
        WHERE start >= ? and start < ?
        ORDER BY start ASC'''
        time_bounds = (str(start_time_ns), str(end_time_ns))

        stacks_by_id = {}
        for row in self.__get_connection().execute(query, time_bounds):
            pid, tid = tu.convert_global_tid(row['globalTid'])
            pid_rejected = target_pid is not None and pid != target_pid
            tid_rejected = target_tid is not None and tid != target_tid
            if pid_rejected or tid_rejected:
                continue

            stack_id = row['id']
            if stack_id not in stacks_by_id:
                kind = CallStackType.SAMPLED if row['cpucycles'] == 1 else CallStackType.EVENT
                stacks_by_id[stack_id] = CallStack(id=stack_id, tid=tid, pid=pid, time=row['start'], stack_type=kind)

            frame = CallStackFrame(function=row['symbol'], module=row['module'], depth=row['stackDepth'])
            stacks_by_id[stack_id].stack.append(frame)

        callstack_sample_list = list(stacks_by_id.values())
        for callstack in callstack_sample_list:
            # Reverse NSys ordering, then renumber depths from 0.
            callstack.stack.sort(key=operator.attrgetter('depth'), reverse=True)
            for new_depth, frame in enumerate(callstack.stack):
                frame.depth = new_depth

        return callstack_sample_list

    ####################################################
    # Get the duration from DB
    ####################################################
    def _get_analysis_duration(self) -> Optional[float]:
        """Return the capture duration stored in the report, or None if none is recorded.

        ANALYSIS_DETAILS.duration is preferred; META_DATA_CAPTURE.RUN_DURATION_MS is the
        fallback. Only the first row of each table is consulted, and an empty table falls
        through to the next source.
        NOTE(review): the two sources may use different units (the fallback's column name
        suggests milliseconds) — confirm what callers expect.
        """
        duration_sources = (
            ('ANALYSIS_DETAILS', 'SELECT duration FROM ANALYSIS_DETAILS', 'duration'),
            ('META_DATA_CAPTURE', 'SELECT RUN_DURATION_MS FROM META_DATA_CAPTURE', 'RUN_DURATION_MS'),
        )
        for table, query, column in duration_sources:
            if not self.__supports_table(table):
                continue
            first_row = self.__get_connection().execute(query).fetchone()
            if first_row is not None:
                return first_row[column]
        return None

    ####################################################
    # Get the core counts from DB
    ####################################################
    def _get_core_count(self) -> int:
        """Return the core count from the first 'CpuCores' entry of TARGET_INFO_SYSTEM_ENV.

        Returns 0 when the entry is absent. If the table itself is missing, an error is
        logged first and 0 is returned.
        """
        if not self.__supports_table('TARGET_INFO_SYSTEM_ENV'):
            logger.error("TARGET_INFO_SYSTEM_ENV table not found, CPU core count not available")
            return 0

        env_rows = self.__get_cursor().execute('select name,value from TARGET_INFO_SYSTEM_ENV')
        for env_row in env_rows:
            if env_row['name'] == "CpuCores":
                return int(env_row[1])
        return 0

    ####################################################
    # Is the CPU Config supported
    ####################################################
    def __supports_cpu_config(self) -> bool:
        """Report whether the first 'CpuInfo' env entry carries per-core EfficiencyType data.

        Only the first item of the CpuInfo 'items' list is inspected. False when the table or
        the entry is missing, or when the list is empty.
        """
        if not self.__supports_table('TARGET_INFO_SYSTEM_ENV'):
            return False

        for env_row in self.__get_cursor().execute('select name,value from TARGET_INFO_SYSTEM_ENV'):
            if env_row['name'] != "CpuInfo":
                continue
            core_items = json.loads(env_row[1])['items']
            return len(core_items) > 0 and 'EfficiencyType' in core_items[0]
        return False

    ####################################################
    # Get the detailed CPU config from DB
    ####################################################
    def _get_cpu_config(self) -> tu.CPUConfig:
        """Build a CPUConfig describing the P-core / E-core split from the 'CpuInfo' env entry.

        Every item in CpuInfo['items'] counts as one core of its EfficiencyType ("P" or "E").
        SMT/HT detection is not supported, so logical counts equal physical counts.
        p_core_starting_index is the smallest 'cpuid' among the P-core items.

        Returns:
            A new CPUConfig, or None (after logging an error) when the table, the CpuInfo entry,
            or its EfficiencyType data is missing.
        """
        if self.__supports_table('TARGET_INFO_SYSTEM_ENV'):
            cursor = self.__get_cursor()

            # Get env info
            for row in cursor.execute('select name,value from TARGET_INFO_SYSTEM_ENV'):
                if row['name'] != "CpuInfo":
                    continue
                items = json.loads(row[1])['items']
                if not items or 'EfficiencyType' not in items[0]:
                    continue

                # Create an instance: assigning attributes on the class itself would mutate
                # shared class state and hand every caller the same object.
                # NOTE(review): assumes tu.CPUConfig() is default-constructible — confirm.
                cpu_config = tu.CPUConfig()
                # SMT/HT detection currently not supported
                cpu_config.physical_e_core_count = sum(1 for item in items if item['EfficiencyType'] == "E")
                cpu_config.logical_e_core_count = cpu_config.physical_e_core_count
                cpu_config.physical_p_core_count = sum(1 for item in items if item['EfficiencyType'] == "P")
                cpu_config.logical_p_core_count = cpu_config.physical_p_core_count
                # NOTE(review): min() raises ValueError if no item is a P-core.
                cpu_config.p_core_starting_index = min(item['cpuid'] for item in items if item['EfficiencyType'] == "P")

                return cpu_config

        logger.error("TARGET_INFO_SYSTEM_ENV table not found or CpuInfo not found, CPU core config not available.")
        return None

    ####################################################
    # Get events for the given time range and tid/pid
    ####################################################
    # pylint: disable=too-many-return-statements
    def _get_events(self,
                    event_type: TraceLoaderEvents,
                    start_time_ns: int,
                    end_time_ns: int,
                    target_pid: int,
                    target_tid: int) -> (dict, dict):
        """Get the events for the given time range."""
        duration_dict = None
        count_dict = None
        if event_type is TraceLoaderEvents.DX12_API_CALLS:
            duration_dict, count_dict = self.__get_events_from_sql_table(sql_table="DX12_API",
                                                                         nameId_column_name="nameId",
                                                                         lookup_nameId_in_string_table=True,
                                                                         start_time_ns=start_time_ns,
                                                                         end_time_ns=end_time_ns,
                                                                         target_pid=target_pid,
                                                                         target_tid=target_tid)

        if event_type is TraceLoaderEvents.CUDA_API_CALLS:
            duration_dict, count_dict = self.__get_events_from_sql_table(sql_table="CUPTI_ACTIVITY_KIND_RUNTIME",
                                                                         nameId_column_name="nameId",
                                                                         lookup_nameId_in_string_table=True,
                                                                         start_time_ns=start_time_ns,
                                                                         end_time_ns=end_time_ns,
                                                                         target_pid=target_pid,
                                                                         target_tid=target_tid)

        if event_type is TraceLoaderEvents.CUDA_GPU_KERNELS:
            duration_dict, count_dict = self.__get_events_from_sql_table(sql_table="CUPTI_ACTIVITY_KIND_KERNEL",
                                                                         nameId_column_name="demangledName",
                                                                         lookup_nameId_in_string_table=True,
                                                                         start_time_ns=start_time_ns,
                                                                         end_time_ns=end_time_ns,
                                                                         target_pid=target_pid,
                                                                         target_tid=None)

        if event_type is TraceLoaderEvents.NVTX_MARKERS:
            dur_a, count_a = self.__get_events_from_sql_table(sql_table="NVTX_EVENTS",
                                                              nameId_column_name="text",
                                                              lookup_nameId_in_string_table=False,
                                                              start_time_ns=start_time_ns,
                                                              end_time_ns=end_time_ns,
                                                              target_pid=target_pid,
                                                              target_tid=target_tid)
            dur_b, count_b = self.__get_events_from_sql_table(sql_table="NVTX_EVENTS",
                                                              nameId_column_name="textId",
                                                              lookup_nameId_in_string_table=True,
                                                              start_time_ns=start_time_ns,
                                                              end_time_ns=end_time_ns,
                                                              target_pid=target_pid,
                                                              target_tid=target_tid)
            duration_dict = dur_a | dur_b
            count_dict = count_a | count_b

        if event_type is TraceLoaderEvents.MPI_MARKERS:
            duration_dict, count_dict = self.__get_events_from_sql_table(sql_table="MPI_OTHER_EVENTS",
                                                                         nameId_column_name="textId",
                                                                         lookup_nameId_in_string_table=True,
                                                                         start_time_ns=start_time_ns,
                                                                         end_time_ns=end_time_ns,
                                                                         target_pid=target_pid,
                                                                         target_tid=target_tid)

        if event_type is TraceLoaderEvents.PIX_MARKERS:
            duration_dict, count_dict = self.__get_pix_markers(start_time_ns=start_time_ns,
                                                               end_time_ns=end_time_ns,
                                                               target_pid=target_pid,
                                                               target_tid=target_tid)

        if event_type is TraceLoaderEvents.DX12_GPU_WORKLOAD:
            duration_dict, count_dict = self.__get_events_from_sql_table(sql_table="DX12_WORKLOAD",
                                                                         nameId_column_name="textId",
                                                                         lookup_nameId_in_string_table=True,
                                                                         start_time_ns=start_time_ns,
                                                                         end_time_ns=end_time_ns,
                                                                         target_pid=target_pid,
                                                                         target_tid=target_tid)

        if event_type is TraceLoaderEvents.DXGKRNL_PROFILE_RANGE:
            duration_dict, count_dict = self.__get_dxgkrnl_profile_ranges(start_time_ns=start_time_ns,
                                                                          end_time_ns=end_time_ns,
                                                                          target_pid=target_pid,
                                                                          target_tid=target_tid)

        if event_type is TraceLoaderEvents.ETW_EVENTS:
            duration_dict, count_dict = self.__get_etw_events(start_time_ns=start_time_ns,
                                                              end_time_ns=end_time_ns,
                                                              target_pid=target_pid,
                                                              target_tid=target_tid)
        return duration_dict, count_dict

    ####################################################
    # Get events for the given time range and tid/pid
    ####################################################
    def _get_ordered_events(self,
                            event_type: TraceLoaderEvents,
                            start_time_ns: int,
                            end_time_ns: int,
                            target_pid: int,
                            target_tid: int) -> List:
        """Return the events of ``event_type`` in the given time range, in start-time order.

        Each supported event type is routed to its private extractor, with the time range and the
        pid/tid filters forwarded. CUDA GPU kernels are never filtered by thread id
        (``target_tid=None`` is passed). Returns None for event types not handled here.
        """
        window = {"start_time_ns": start_time_ns,
                  "end_time_ns": end_time_ns,
                  "target_pid": target_pid}

        if event_type is TraceLoaderEvents.NVTX_MARKERS:
            return self.__get_ordered_events_from_sql_table(sql_table="NVTX_EVENTS",
                                                            name_column_name="text",
                                                            name_id_column_name="textId",
                                                            target_tid=target_tid,
                                                            **window)

        if event_type is TraceLoaderEvents.NVTX_GPU_MARKERS:
            return self.__get_ordered_nvtx_projected_gpu_events_from_gpu_time(target_tid=target_tid, **window)

        if event_type is TraceLoaderEvents.PIX_MARKERS:
            return self.__get_ordered_pix_markers(target_tid=target_tid, **window)

        if event_type is TraceLoaderEvents.DX12_GPU_WORKLOAD:
            return self.__get_ordered_events_from_sql_table(sql_table="DX12_WORKLOAD",
                                                            name_column_name=None,
                                                            name_id_column_name="textId",
                                                            target_tid=target_tid,
                                                            **window)

        if event_type is TraceLoaderEvents.CUDA_GPU_KERNELS:
            # Kernel rows are only filtered by process, never by thread
            return self.__get_ordered_events_from_sql_table(sql_table="CUPTI_ACTIVITY_KIND_KERNEL",
                                                            name_column_name=None,
                                                            name_id_column_name="demangledName",
                                                            target_tid=None,
                                                            **window)

        return None

    ####################################################
    #
    # Private functions to handle different events and regions
    #
    ####################################################
    def __get_nvtx_frame_times(self,
                               marker_name_starts_with,
                               domain_name: Optional[str] = None,
                               start_time_ns: Optional[float] = None,
                               end_time_ns: Optional[float] = None,
                               target_pid: Optional[int] = None) -> (float, list[FrameDurations]):
        """Get CPU-side frame times from NVTX ranges whose name starts with ``marker_name_starts_with``.

        :param marker_name_starts_with: prefix passed to ``str.startswith`` to select ranges.
        :param domain_name: if set, ranges whose NVTX domain is known and differs are skipped.
        :param start_time_ns: start of the search window; None means unbounded.
        :param end_time_ns: end of the search window; None means unbounded.
        :param target_pid: if set, only ranges from this process are used.
        :return: (average duration, list of FrameDurations) for the accepted ranges, or (0, None)
                 when no range was accepted or the NVTX_EVENTS table is missing.
        """
        if not self.__supports_table("NVTX_EVENTS"):
            logger.warning("NVTX_EVENTS sql table not supported")
            # Same shape as the "nothing found" result so callers can always unpack it
            return 0, None

        # Missing bounds mean "unbounded" rather than failing on None arithmetic below
        lower_ns = -sys.maxsize if start_time_ns is None else start_time_ns
        upper_ns = sys.maxsize if end_time_ns is None else end_time_ns
        full_duration = upper_ns - lower_ns

        frametime_list_nvtx = []
        total_duration = 0

        # `end >= ?` also excludes rows whose end is NULL (instant markers)
        query = 'SELECT * FROM NVTX_EVENTS WHERE end >= ? and start < ? ORDER BY start ASC'
        for row in self.__get_connection().execute(query, (str(lower_ns), str(upper_ns))):
            pid, tid = tu.convert_global_tid(row['globalTid'])
            if target_pid is not None and pid != target_pid:
                continue

            domain_id = row['domainId']
            if domain_name and self.nvtx_domain_dict and domain_id in self.nvtx_domain_dict and domain_name != self.nvtx_domain_dict[domain_id]:
                continue

            # Names are either registered strings (textId) or inline text
            name_id = row['textId']
            name = self.get_string(name_id) if name_id else str(row['text'])
            if name is None or not name.startswith(marker_name_starts_with):
                continue

            start = row['start']
            duration = row['end'] - start

            # Some markers can get messed up and have a marker which is as long as the trace!
            if duration > full_duration:
                continue

            # The markers in general can sometimes get messed up - put in a soft limit of 1s
            if duration > 1000000000:
                continue

            total_duration += duration
            frametime_list_nvtx.append(FrameDurations(start, duration))

        if not frametime_list_nvtx:
            return 0, None
        return total_duration / len(frametime_list_nvtx), frametime_list_nvtx

    ####################################################
    # Helper to extract NVTX info from markers
    #
    # This is reasonably complex:
    # 1. Find the NVTX markers we care about
    # 2. For each marker, find the associated CUDA kernel launches
    # 3. Track those kernel launches on the GPU timeline
    # 4. Turn each marker's GPU span into a frame duration
    ####################################################
    def __get_nvtx_gpu_projected_frame_times(self,
                                             marker_name_starts_with,
                                             domain_name: Optional[str] = None,
                                             start_time_ns: Optional[float] = None,
                                             end_time_ns: Optional[float] = None,
                                             target_pid: Optional[int] = None) -> (float, list[FrameDurations]):
        """Get GPU-side frame times for NVTX ranges whose name starts with ``marker_name_starts_with``.

        Each matching NVTX range is linked to the CUDA runtime calls made inside it on the same
        pid/tid. Through their correlationId, those calls are linked to the GPU kernels they launched.
        The GPU "frame" for a range runs from its earliest kernel start to its latest kernel end.

        :param marker_name_starts_with: prefix passed to ``str.startswith`` to select ranges.
        :param domain_name: if set, ranges whose NVTX domain is known and differs are skipped.
        :param start_time_ns: start of the search window; None means unbounded.
        :param end_time_ns: end of the search window; None means unbounded.
        :param target_pid: if set, only this process is considered.
        :return: (average GPU duration, list of FrameDurations), or (0, None) when nothing was
                 found or a required table is missing.
        """
        for table in ("NVTX_EVENTS", "CUPTI_ACTIVITY_KIND_RUNTIME", "CUPTI_ACTIVITY_KIND_KERNEL"):
            if not self.__supports_table(table):
                logger.warning(f"{table} sql table not supported")
                return 0, None

        def pid_tid_of(row):
            """(pid, tid) from globalTid, (pid, None) from globalPid, otherwise (None, None)."""
            if 'globalTid' in row.keys():
                return tu.convert_global_tid(row['globalTid'])
            if 'globalPid' in row.keys():
                row_pid, _ = tu.convert_global_tid(row['globalPid'])
                return row_pid, None
            return None, None

        # Missing bounds mean "unbounded"; str(None) would silently match nothing in SQL
        lower_ns = -sys.maxsize if start_time_ns is None else start_time_ns
        upper_ns = sys.maxsize if end_time_ns is None else end_time_ns
        window = (str(lower_ns), str(upper_ns))

        # 1. Find the NVTX markers we care about. The list index doubles as the marker id.
        nvtx_marker_list = []
        query = 'SELECT * FROM NVTX_EVENTS WHERE end >= ? and start < ? ORDER BY start ASC'
        for row in self.__get_connection().execute(query, window):
            pid, tid = tu.convert_global_tid(row['globalTid'])
            if target_pid is not None and pid != target_pid:
                continue

            domain_id = row['domainId']
            if domain_name and self.nvtx_domain_dict and domain_id in self.nvtx_domain_dict and domain_name != self.nvtx_domain_dict[domain_id]:
                continue

            name_id = row['textId']
            name = self.get_string(name_id) if name_id else str(row['text'])
            if name is None or not name.startswith(marker_name_starts_with):
                continue

            nvtx_marker_list.append((name, row['start'], row['end'], len(nvtx_marker_list), pid, tid))

        # Without markers there is nothing to project (and the cursor below needs at least one)
        if not nvtx_marker_list:
            return 0, None

        # 2. Map CUDA runtime calls made inside a marker (same pid/tid) to that marker.
        #    Markers and calls are both ordered by start, so a single forward cursor walks the markers.
        #    NOTE(review): with one cursor, a call can be skipped while the cursor sits on an
        #    overlapping marker from a different thread.
        query = 'SELECT * FROM CUPTI_ACTIVITY_KIND_RUNTIME WHERE (end >= ? and start < ?) or (start >= ? and start <= ?) ORDER BY start ASC'
        marker_index = 0
        _, nvtx_start, nvtx_end, nvtx_correlation_id, nvtx_pid, nvtx_tid = nvtx_marker_list[0]
        correlation_dict = {}
        for row in self.__get_connection().execute(query, window * 2):
            pid, tid = pid_tid_of(row)
            if target_pid is not None and pid != target_pid:
                continue

            call_start = row['start']

            # Move the cursor past markers that ended before this call started
            while (marker_index + 1) < len(nvtx_marker_list) and call_start > nvtx_end:
                marker_index += 1
                _, nvtx_start, nvtx_end, nvtx_correlation_id, nvtx_pid, nvtx_tid = nvtx_marker_list[marker_index]

            # Still after the current marker's end: every remaining call is past the last marker
            if call_start > nvtx_end:
                break

            # Before the first marker, or in a gap between two markers: not inside any marker.
            # This check must run after the cursor has advanced.
            if call_start < nvtx_start:
                continue

            if pid != nvtx_pid or tid != nvtx_tid:
                continue

            # Track the correlation, mapping the CUDA API call to the NVTX marker
            correlation_dict[row['correlationId']] = nvtx_correlation_id

        # 3. Grow each marker's GPU span with the kernels its CUDA calls launched
        nvtx_gpu_regions = {}
        query = 'SELECT * FROM CUPTI_ACTIVITY_KIND_KERNEL WHERE (end >= ? and start < ?) or (start >= ? and start <= ?) ORDER BY start ASC'
        for row in self.__get_connection().execute(query, window * 2):
            pid, _ = pid_tid_of(row)
            if target_pid is not None and pid != target_pid:
                continue

            marker_id = correlation_dict.get(row['correlationId'])
            if marker_id is None:
                continue

            gpu_start, gpu_end = nvtx_gpu_regions.get(marker_id, (sys.maxsize, -sys.maxsize))
            nvtx_gpu_regions[marker_id] = (min(gpu_start, row['start']), max(gpu_end, row['end']))

        # 4. Walk the original markers, in order, creating the GPU frametimes
        total_duration = 0
        nvtx_gpu_duration_list = []
        for _, _, _, marker_id, _, _ in nvtx_marker_list:
            region = nvtx_gpu_regions.get(marker_id)
            if region is None:
                continue
            gpu_start, gpu_end = region
            duration = gpu_end - gpu_start
            total_duration += duration
            nvtx_gpu_duration_list.append(FrameDurations(gpu_start, duration))

        if not nvtx_gpu_duration_list:
            return 0, None
        return total_duration / len(nvtx_gpu_duration_list), nvtx_gpu_duration_list

    ####################################################
    # Helper to extract projected GPU times
    ####################################################
    def __get_ordered_nvtx_projected_gpu_events_from_gpu_time(self,
                                                              start_time_ns: int,
                                                              end_time_ns: int,
                                                              target_pid: int,
                                                              target_tid: int) -> List:
        """Project NVTX ranges onto the GPU timeline through the CUDA kernels launched inside them.

        1. Collect the CUDA kernels in [start_time_ns, end_time_ns] (only ``target_pid`` if set).
        2. Find the CUDA runtime call that launched each kernel (matched by correlationId within the
           same process).
        3. For each NVTX range that fully encloses at least one launch on the CPU, report the GPU
           span covered by those launches' kernels.

        If an NVTX row carries a JSON payload with a 'Layers' or 'Name' entry, that entry is used
        as the display name.

        :param target_tid: accepted for interface symmetry; not used.
        :return: list of (name, gpu_start, gpu_duration, tooltip) in NVTX start order, or None if a
                 required table is missing.
        """
        for table in ("NVTX_EVENTS", "CUPTI_ACTIVITY_KIND_RUNTIME", "CUPTI_ACTIVITY_KIND_KERNEL"):
            if not self.__supports_table(table):
                logger.warning(f"{table} sql table not supported")
                return None

        def pid_of(row):
            """Process id from globalTid or globalPid; None if the row has neither column."""
            if 'globalTid' in row.keys():
                row_pid, _ = tu.convert_global_tid(row['globalTid'])
                return row_pid
            if 'globalPid' in row.keys():
                row_pid, _ = tu.convert_global_tid(row['globalPid'])
                return row_pid
            return None

        # 1. Walk the GPU timeline looking for CUDA kernel execution
        kernels = []  # (gpu_start, gpu_end, correlation_id, kernel_name, pid)
        query = 'SELECT * FROM CUPTI_ACTIVITY_KIND_KERNEL WHERE (end >= ? and start < ?) or (start >= ? and start <= ?) ORDER BY start ASC'
        params = (str(start_time_ns), str(end_time_ns), str(start_time_ns), str(end_time_ns))
        for row in self.__get_connection().execute(query, params):
            pid = pid_of(row)
            if target_pid is not None and pid != target_pid:
                continue
            kernels.append((row['start'], row['end'], row['correlationId'], self.get_string(row['demangledName']), pid))

        if not kernels:
            return []

        # 2. Find the CUDA API call(s) that launched each kernel. Each launch carries the GPU times
        #    of its own kernel, so several kernels sharing one correlationId all contribute.
        launches = []  # (cpu_start, cpu_end, gpu_start, gpu_end, kernel_name, pid)
        launch_rows_by_correlation = {}
        query = 'SELECT * FROM CUPTI_ACTIVITY_KIND_RUNTIME WHERE correlationId == ? ORDER BY start ASC'
        for gpu_start, gpu_end, correlation_id, kernel_name, kernel_pid in kernels:
            rows = launch_rows_by_correlation.get(correlation_id)
            if rows is None:
                rows = self.__get_connection().execute(query, (str(correlation_id),)).fetchall()
                launch_rows_by_correlation[correlation_id] = rows
            for row in rows:
                launch_pid = pid_of(row)
                # The correlationId lookup above is not process-scoped; don't pair across processes
                if kernel_pid is not None and launch_pid is not None and launch_pid != kernel_pid:
                    continue
                launches.append((row['start'], row['end'], gpu_start, gpu_end, kernel_name, launch_pid))

        if not launches:
            return []

        # 3. For each NVTX range, find which CUDA launches it encloses and derive the GPU span
        cpu_start = min(launch[0] for launch in launches)
        cpu_end = max(launch[1] for launch in launches)
        margin = cpu_end - cpu_start  # Create a reasonable buffer around the launches
        query = 'SELECT * FROM NVTX_EVENTS WHERE end >= ? and start < ? ORDER BY start ASC'
        marker_duration_list = []
        for row in self.__get_connection().execute(query, (str(cpu_start - margin), str(cpu_end + margin))):
            start = row['start']
            end = row['end']
            if end < cpu_start or start > cpu_end:
                continue

            marker_pid, _ = tu.convert_global_tid(row['globalTid'])
            if target_pid is not None and marker_pid != target_pid:
                continue

            name_id = row['textId']
            name = self.get_string(name_id) if name_id else str(row['text'])

            # Optional JSON payload; '\x1f' separators are replaced before parsing
            onnx_layer = None
            json_id = row['jsonTextId']
            json_string = self.get_string(json_id) if json_id else None
            if json_string:
                try:
                    payload = json.loads(json_string.replace('\x1f', ' '))
                except json.JSONDecodeError:
                    logger.debug("Ignoring malformed NVTX JSON payload: %s", json_string)
                    payload = None
                if isinstance(payload, dict):
                    if 'Layers' in payload:
                        onnx_layer = payload['Layers']
                    elif 'Name' in payload:
                        onnx_layer = payload['Name']
            if onnx_layer:
                name = onnx_layer

            gpu_start = sys.maxsize
            gpu_end = -sys.maxsize
            match_count = 0
            first_kernel_names = []
            for launch_start, launch_end, kernel_gpu_start, kernel_gpu_end, kernel_name, launch_pid in launches:
                # The launch must lie entirely inside this NVTX range on the CPU timeline
                if launch_start < start or launch_end > end:
                    continue
                if marker_pid is not None and launch_pid is not None and launch_pid != marker_pid:
                    continue
                gpu_start = min(gpu_start, kernel_gpu_start)
                gpu_end = max(gpu_end, kernel_gpu_end)
                if match_count < 2:
                    first_kernel_names.append(kernel_name)
                match_count += 1

            if match_count:
                tooltip = f"{match_count} kernels." if match_count > 2 else ", ".join(first_kernel_names)
                marker_duration_list.append((name, gpu_start, gpu_end - gpu_start, tooltip))
        return marker_duration_list

    ####################################################
    #
    # Get markers from a table for the given time period and tid/pid
    #
    ####################################################
    def __get_events_from_sql_table(self,
                                    sql_table: str,
                                    nameId_column_name: str,
                                    lookup_nameId_in_string_table: bool,
                                    start_time_ns: int,
                                    end_time_ns: int,
                                    target_pid: int,
                                    target_tid: int) -> (dict, dict):
        """Sum durations and count occurrences per event name in ``sql_table``.

        The name is read from ``nameId_column_name``. When ``lookup_nameId_in_string_table`` is
        True, that value is resolved through ``get_string``. Rows with no name or the name
        "Unknown" are skipped. Rows with a falsy ``end`` are counted but add no duration.

        :return: (name -> total duration, name -> count), or (None, None) if the table is missing.
        """
        if not self.__supports_table(sql_table):
            if not self.quiet:
                logger.warning(f"{sql_table} sql table not supported")
            return None, None

        durations = {}
        counts = {}

        # Note : Table names can not be bound as execute() parameters, so the table name is
        # formatted into the query; it was checked against the known tables above.
        query = f'SELECT * FROM {sql_table} WHERE (end >= ? and start < ?) or (start >= ? and start <= ?) ORDER BY start ASC'
        bounds = (str(start_time_ns), str(end_time_ns))

        for row in self.__get_connection().execute(query, bounds * 2):
            columns = row.keys()
            if 'globalTid' in columns:
                pid, tid = tu.convert_global_tid(row['globalTid'])
            elif 'globalPid' in columns:
                pid, _ = tu.convert_global_tid(row['globalPid'])
                tid = None
            else:
                pid = tid = None

            wrong_process = target_pid is not None and pid != target_pid
            wrong_thread = target_tid is not None and tid != target_tid
            if wrong_process or wrong_thread:
                continue

            raw_name = row[nameId_column_name]
            name = self.get_string(raw_name) if lookup_nameId_in_string_table else raw_name
            if not name or name == "Unknown":
                continue

            durations.setdefault(name, 0.0)
            counts[name] = counts.get(name, 0) + 1

            # Some markers have no end, and are simple events, rather than durations
            if row['end']:
                durations[name] += row['end'] - row['start']

        return durations, counts

    ####################################################
    #
    # Check whether a table has events for the given time period and pid
    #
    ####################################################
    def __are_events_available_in_sql_table(self,
                                            sql_table: str,
                                            start_time_ns: int,
                                            end_time_ns: int,
                                            target_pid: int) -> bool:
        """Return True if ``sql_table`` has at least one event of ``target_pid`` in the time range.

        The owning process is decoded from the ``globalTid`` column when the table has one,
        otherwise from ``globalPid``. Returns False if the table is unsupported or has neither
        column.
        """
        if not self.__supports_table(sql_table):
            return False

        if self.__does_sql_column_name_exist_in_table('globalTid', sql_table):
            id_column = 'globalTid'
        elif self.__does_sql_column_name_exist_in_table('globalPid', sql_table):
            id_column = 'globalPid'
        else:
            return False

        # Note : Table/column names can not be bound as execute() parameters; both were checked
        # above before being formatted into the query. No ORDER BY: this is an existence check
        # over DISTINCT ids, so ordering is only extra work.
        query = f'SELECT DISTINCT {id_column} FROM {sql_table} WHERE (end >= ? and start < ?) or (start >= ? and start <= ?)'
        params = (str(start_time_ns), str(end_time_ns), str(start_time_ns), str(end_time_ns))
        for row in self.__get_connection().execute(query, params):
            pid, _ = tu.convert_global_tid(row[id_column])
            if pid == target_pid:
                return True

        return False

    ####################################################
    #
    # Get markers from a table, ordered by start time, for the given time period and tid/pid
    #
    ####################################################
    def __get_ordered_events_from_sql_table(self,
                                            sql_table: str,
                                            name_column_name: str,
                                            name_id_column_name: str,
                                            start_time_ns: int,
                                            end_time_ns: int,
                                            target_pid: int,
                                            target_tid: int) -> List:
        """Extract events from ``sql_table`` as a start-ordered list.

        The name is resolved from ``name_id_column_name`` through ``get_string`` when that column
        is given. If that yields nothing (or "Unknown"), the inline ``name_column_name`` value is
        used instead. For example, NVTX rows can carry ``text`` rather than a ``textId``.
        Rows still without a usable name are skipped.

        :return: list of (name, start, duration, None); the duration is 0 for rows with no end.
                 Returns None if the table is missing, matching the other ordered-event extractors.
        """
        if not self.__supports_table(sql_table):
            if not self.quiet:
                logger.warning(f"{sql_table} sql table not supported")
            # Previously (None, None): a truthy tuple that callers could mistake for event data
            return None

        api_duration_list = []

        # Note : Table names can not be variable in the execute() function. We shouldn't
        # use string functions to create the query, as they are liable to attack, but we check the name
        # of the table exists before use so should be safe
        query = f'SELECT * FROM {sql_table} WHERE (end >= ? and start < ?) or (start >= ? and start <= ?) ORDER BY start ASC'

        for row in self.__get_connection().execute(query, (str(start_time_ns), str(end_time_ns), str(start_time_ns), str(end_time_ns))):
            pid = None
            tid = None
            if 'globalTid' in row.keys():
                pid, tid = tu.convert_global_tid(row['globalTid'])
            elif 'globalPid' in row.keys():
                pid, _ = tu.convert_global_tid(row['globalPid'])
            if target_pid is not None and pid != target_pid:
                continue
            if target_tid is not None and tid != target_tid:
                continue

            name = None
            if name_id_column_name:
                name = self.get_string(row[name_id_column_name])
            # Fall back to the inline name column when the string-table lookup gave nothing
            if name_column_name and (not name or name == "Unknown"):
                name = row[name_column_name]

            if not name or name == "Unknown":
                continue

            # Some markers have no end, and are simple events, rather than durations
            start = row['start']
            end = row['end']
            api_duration_list.append((name, start, end - start if end else 0, None))

        return api_duration_list

    ####################################################
    # Get DX12 PIX Markers given time period and tid/pid
    ####################################################
    def __get_pix_markers(self,
                          target_pid: int,
                          target_tid: int,
                          start_time_ns: int,
                          end_time_ns: int) -> (dict, dict):
        """Aggregate D3D12 PIX markers by name.

        Only rows of D3D12_PIX_DEBUG_API with start >= start_time_ns and end < end_time_ns are
        read. Rows are grouped by correlationId:
        - The first row with a textId opens a marker, recording that row's start/end.
        - A later row without a textId re-spans it from the opening row's end to that row's start.

        :return: (marker name -> total duration, marker name -> count), or (None, None) if the
                 table is missing.
        """
        if not self.__supports_table('D3D12_PIX_DEBUG_API'):
            if not self.quiet:
                logger.warning("D3D12_PIX_DEBUG_API not supported")
            return None, None

        open_markers: dict[int, EventMarker] = {}

        query = '''SELECT * FROM D3D12_PIX_DEBUG_API
        WHERE start >= ? and end < ?
        ORDER BY start ASC'''

        for row in self.__get_connection().execute(query, (str(start_time_ns), str(end_time_ns))):
            pid, tid = tu.convert_global_tid(row['globalTid'])
            if target_pid is not None and pid != target_pid:
                continue
            if target_tid is not None and tid != target_tid:
                continue

            correlation_id = row['correlationId']
            text_id = row['textId']

            if text_id:
                # Opening row: only the first one per correlationId is kept
                if correlation_id not in open_markers:
                    open_markers[correlation_id] = EventMarker(id=row['nameId'],
                                                               start=row['start'],
                                                               end=row['end'],
                                                               correlation_id=correlation_id,
                                                               name_id=text_id,
                                                               tid=tid,
                                                               pid=pid)
            elif correlation_id in open_markers:
                # Closing row: the marker spans from the opening call's end to this call's start
                marker = open_markers[correlation_id]
                marker.start, marker.end = marker.end, row['start']

        # Calculate the hotspots based on the names
        durations = {}
        counts = {}
        for marker in sorted(open_markers.values(), key=operator.attrgetter('start')):
            label = self.get_string(marker.name_id)
            durations[label] = durations.get(label, 0) + (marker.end - marker.start)
            counts[label] = counts.get(label, 0) + 1

        return durations, counts

    ####################################################
    # Get DX12 PIX Markers, ordered by start time, for the given time period and tid/pid
    ####################################################
    def __get_ordered_pix_markers(self,
                                  target_pid: int,
                                  target_tid: int,
                                  start_time_ns: int,
                                  end_time_ns: int) -> List:
        """Extract PIX markers as a list ordered by start time.

        Rows of D3D12_PIX_DEBUG_API are paired by ``correlationId``:

        - a row with a ``textId`` opens a marker;
        - a later row with no ``textId`` closes it.

        A closed marker spans from the end of the opening call to the start of the
        closing call. An unclosed marker keeps the opening row's own start/end.

        :param target_pid: Only keep rows from this process; ``None`` keeps all.
        :param target_tid: Only keep rows from this thread; ``None`` keeps all.
        :param start_time_ns: Inclusive lower bound on row start (ns).
        :param end_time_ns: Exclusive upper bound on row end (ns).
        :return: List of ``(name, start, duration, None)`` tuples sorted by start.
                 The list is empty when the table is not present.
        """
        if not self.__supports_table('D3D12_PIX_DEBUG_API'):
            if not self.quiet:
                logger.warning("D3D12_PIX_DEBUG_API not supported")
            # An empty list honours the List contract. A (None, None) tuple would be
            # iterated by callers as two bogus marker entries.
            return []

        pix_marker_correlation_dict: dict[int, EventMarker] = {}

        query = '''SELECT * FROM D3D12_PIX_DEBUG_API
        WHERE start >= ? and end < ?
        ORDER BY start ASC'''

        for row in self.__get_connection().execute(query, (str(start_time_ns), str(end_time_ns))):
            pid, tid = tu.convert_global_tid(row['globalTid'])
            if target_pid is not None and pid != target_pid:
                continue
            if target_tid is not None and tid != target_tid:
                continue

            name_id = row['textId']
            correlation_id = row['correlationId']
            if name_id and correlation_id not in pix_marker_correlation_dict:
                # Opening row: its end is provisional until the closing row is seen.
                em = EventMarker()
                em.id = row['nameId']
                em.start = row['start']
                em.end = row['end']
                em.correlation_id = correlation_id
                em.name_id = name_id
                em.tid = tid
                em.pid = pid

                pix_marker_correlation_dict[correlation_id] = em
            elif not name_id and correlation_id in pix_marker_correlation_dict:
                # Closing row: the marker runs from the end of the opening call to here.
                em = pix_marker_correlation_dict[correlation_id]
                em.start = em.end
                em.end = row['start']

        # (name, start, duration, None) tuples in start-time order.
        return [(self.get_string(pm.name_id), pm.start, pm.end - pm.start, None)
                for pm in sorted(pix_marker_correlation_dict.values(), key=operator.attrgetter('start'))]

    ####################################################
    # Get PIX frametimes
    ####################################################
    def __get_pix_frame_times(self,
                              marker_name: str,
                              start_time_ns: Optional[float] = None,
                              end_time_ns: Optional[float] = None,
                              target_pid: Optional[int] = None) -> (float, list[FrameDurations]):
        """Get per-frame durations of the PIX marker named ``marker_name``.

        Rows of D3D12_PIX_DEBUG_API are paired by ``correlationId``:

        - a row with a ``textId`` opens a marker;
        - a later row with no ``textId`` closes it.

        The duration runs from the end of the opening call to the start of the
        closing call.

        :param marker_name: Exact marker string to match.
        :param start_time_ns: Keep rows ending after this time (ns); ``None`` means unbounded.
        :param end_time_ns: Keep rows starting before this time (ns); ``None`` means unbounded.
        :param target_pid: Only include markers from this process when not ``None``.
        :return: ``(average_duration, [FrameDurations, ...])`` in start order. Returns
                 ``(0, None)`` when the table is missing or nothing matched.
        """
        if not self.__supports_table("D3D12_PIX_DEBUG_API"):
            logger.warning("D3D12_PIX_DEBUG_API sql table not supported")
            # Keep the (average, frames) shape so callers can always unpack the result.
            return 0, None

        # Substitute the full signed 64-bit range for omitted bounds. str(None) would
        # bind the TEXT 'None', which SQLite sorts after every number (assuming numeric
        # time columns), so an omitted start bound would otherwise match no rows.
        lower = str(-(2 ** 63) if start_time_ns is None else start_time_ns)
        upper = str(2 ** 63 - 1 if end_time_ns is None else end_time_ns)

        pix_marker_correlation_dict: dict[int, EventMarker] = {}

        query = '''SELECT * FROM D3D12_PIX_DEBUG_API
        WHERE start < ? and end > ?
        ORDER BY start ASC'''

        # duration = first event.end ==> correlation event.start
        for row in self.__get_connection().execute(query, (upper, lower)):
            pid, tid = tu.convert_global_tid(row['globalTid'])
            if target_pid is not None and pid != target_pid:
                continue

            name_id = row['textId']
            correlation_id = row['correlationId']
            if name_id and correlation_id not in pix_marker_correlation_dict:
                em = EventMarker()
                em.id = row['nameId']
                em.start = row['start']
                em.end = row['end']
                em.correlation_id = correlation_id
                em.name_id = name_id
                em.tid = tid
                em.pid = pid

                pix_marker_correlation_dict[correlation_id] = em
            elif not name_id and correlation_id in pix_marker_correlation_dict:
                em = pix_marker_correlation_dict[correlation_id]
                em.start = em.end
                em.end = row['start']

        frametime_list_pix = []
        total_duration = 0
        for pm in sorted(pix_marker_correlation_dict.values(), key=operator.attrgetter('start')):
            if self.get_string(pm.name_id) != marker_name:
                continue

            duration = pm.end - pm.start
            total_duration += duration
            frametime_list_pix.append(FrameDurations(pm.start, duration))

        if frametime_list_pix:
            return total_duration / len(frametime_list_pix), frametime_list_pix
        return 0, None

    ####################################################
    # Get dx12 GPU workload frametimes
    ####################################################
    def __get_dx12_workload_frame_times(self,
                                        marker_name: str,
                                        start_time_ns: Optional[float] = None,
                                        end_time_ns: Optional[float] = None,
                                        target_pid: Optional[int] = None) -> (float, list[FrameDurations]):
        """Get durations of DX12_WORKLOAD rows whose text equals ``marker_name``.

        :param marker_name: Exact workload marker string to match.
        :param start_time_ns: Keep rows ending at or after this time (ns); ``None`` means unbounded.
        :param end_time_ns: Keep rows starting before this time (ns); ``None`` means unbounded.
        :param target_pid: Only include rows from this process when not ``None``.
        :return: ``(average_duration, [FrameDurations, ...])`` in start order. Returns
                 ``(0, None)`` when the table is missing or nothing matched.
        """
        if not self.__supports_table("DX12_WORKLOAD"):
            logger.warning("DX12_WORKLOAD sql table not supported")
            # Keep the (average, frames) shape so callers can always unpack the result.
            return 0, None

        # Substitute the full signed 64-bit range for omitted bounds. str(None) would
        # bind the TEXT 'None', which SQLite sorts after every number (assuming numeric
        # time columns), so an omitted start bound would otherwise match no rows.
        lower = str(-(2 ** 63) if start_time_ns is None else start_time_ns)
        upper = str(2 ** 63 - 1 if end_time_ns is None else end_time_ns)

        query = 'SELECT * FROM DX12_WORKLOAD WHERE end >= ? and start < ? ORDER BY start ASC'

        frametime_list_dx12_workload = []
        total_duration = 0
        for row in self.__get_connection().execute(query, (lower, upper)):
            pid, _ = tu.convert_global_tid(row['globalTid'])
            if target_pid is not None and pid != target_pid:
                continue

            name_id = row['textId']
            if not name_id or self.get_string(name_id) != marker_name:
                continue

            start = row['start']
            duration = row['end'] - start
            total_duration += duration
            frametime_list_dx12_workload.append(FrameDurations(start, duration))

        if frametime_list_dx12_workload:
            return total_duration / len(frametime_list_dx12_workload), frametime_list_dx12_workload
        return 0, None

    ####################################################
    # Get Cuda API frametimes
    ####################################################
    def __get_cuda_api_frame_times(self,
                                   api_name_starts_with: str,
                                   kernel_starts_with: str,
                                   start_time_ns: Optional[float] = None,
                                   end_time_ns: Optional[float] = None,
                                   target_pid: Optional[int] = None) -> (float, list[FrameDurations]):
        """Get durations of CUDA runtime API calls that launched a matching kernel.

        Runtime rows are joined to kernel rows on ``correlationId`` and process. They
        are then filtered by API-name prefix and demangled-kernel-name prefix.

        :param api_name_starts_with: Prefix the runtime API name must start with.
        :param kernel_starts_with: Prefix the demangled kernel name must start with.
        :param start_time_ns: Keep calls ending at or after this time (ns); ``None`` means unbounded.
        :param end_time_ns: Keep calls starting before this time (ns); ``None`` means unbounded.
        :param target_pid: Only include calls from this process when not ``None``.
        :return: ``(average_duration, [FrameDurations, ...])`` in start order. Returns
                 ``(0, None)`` when a table is missing or nothing matched.
        """
        if not self.__supports_table("CUPTI_ACTIVITY_KIND_RUNTIME"):
            logger.warning("CUPTI_ACTIVITY_KIND_RUNTIME sql table not supported")
            return 0, None

        if not self.__supports_table("CUPTI_ACTIVITY_KIND_KERNEL"):
            logger.warning("CUPTI_ACTIVITY_KIND_KERNEL sql table not supported")
            return 0, None

        # Substitute the full signed 64-bit range for omitted bounds. str(None) would
        # bind the TEXT 'None', which SQLite sorts after every number (assuming numeric
        # time columns), so an omitted start bound would otherwise match no rows.
        lower = str(-(2 ** 63) if start_time_ns is None else start_time_ns)
        upper = str(2 ** 63 - 1 if end_time_ns is None else end_time_ns)

        query = """SELECT r.start AS start, r.end AS end, r.nameId as nameId, r.globalTid as globalTid, k.demangledName as kernelName
                FROM CUPTI_ACTIVITY_KIND_RUNTIME AS r
                JOIN CUPTI_ACTIVITY_KIND_KERNEL AS k
                ON k.correlationId == r.correlationId
                AND k.globalPid == (r.globalTid & 0xFFFFFFFFFF000000)
                WHERE r.end >= ? and r.start < ? ORDER BY start ASC"""

        frametime_list = []
        total_duration = 0
        for row in self.__get_connection().execute(query, (lower, upper)):
            pid, _ = tu.convert_global_tid(row['globalTid'])
            if target_pid is not None and pid != target_pid:
                continue

            name_id = row['nameId']
            if not name_id:
                continue

            if not self.get_string(name_id).startswith(api_name_starts_with):
                continue

            # kernelName is the kernel's demangledName string id, resolved the same way as nameId.
            if not self.get_string(row['kernelName']).startswith(kernel_starts_with):
                continue

            start = row['start']
            duration = row['end'] - start
            total_duration += duration
            frametime_list.append(FrameDurations(start, duration))

        if frametime_list:
            return total_duration / len(frametime_list), frametime_list
        return 0, None

    ####################################################
    # Get Cuda GPU kernel frametimes
    ####################################################
    def __get_cuda_gpu_frame_times(self,
                                   starts_with: str,
                                   start_time_ns: Optional[float] = None,
                                   end_time_ns: Optional[float] = None,
                                   target_pid: Optional[int] = None) -> (float, list[FrameDurations]):
        """Get GPU execution durations of CUDA kernels whose demangled name starts with ``starts_with``.

        :param starts_with: Prefix the demangled kernel name must start with.
        :param start_time_ns: Keep kernels ending at or after this time (ns); ``None`` means unbounded.
        :param end_time_ns: Keep kernels starting before this time (ns); ``None`` means unbounded.
        :param target_pid: Only include kernels from this process when not ``None``.
        :return: ``(average_duration, [FrameDurations, ...])`` in start order. Returns
                 ``(0, None)`` when the table is missing or nothing matched.
        """
        if not self.__supports_table("CUPTI_ACTIVITY_KIND_KERNEL"):
            logger.warning("CUPTI_ACTIVITY_KIND_KERNEL sql table not supported")
            # Keep the (average, frames) shape so callers can always unpack the result.
            return 0, None

        # Substitute the full signed 64-bit range for omitted bounds. str(None) would
        # bind the TEXT 'None', which SQLite sorts after every number (assuming numeric
        # time columns), so an omitted start bound would otherwise match no rows.
        lower = str(-(2 ** 63) if start_time_ns is None else start_time_ns)
        upper = str(2 ** 63 - 1 if end_time_ns is None else end_time_ns)

        query = 'SELECT * FROM CUPTI_ACTIVITY_KIND_KERNEL WHERE end >= ? and start < ? ORDER BY start ASC'

        frametime_list = []
        total_duration = 0
        for row in self.__get_connection().execute(query, (lower, upper)):
            # Kernel rows carry a process id only, so only the pid half is used.
            pid, _ = tu.convert_global_tid(row['globalPid'])
            if target_pid is not None and pid != target_pid:
                continue

            name_id = row['demangledName']
            if not name_id or not self.get_string(name_id).startswith(starts_with):
                continue

            start = row['start']
            duration = row['end'] - start
            total_duration += duration
            frametime_list.append(FrameDurations(start, duration))

        if frametime_list:
            return total_duration / len(frametime_list), frametime_list
        return 0, None

    ####################################################
    # Get dxgkrnl ranges for the given time period and tid/pid
    ####################################################
    def __get_dxgkrnl_profile_ranges(self,
                                     target_pid: int,
                                     target_tid: int,
                                     start_time_ns: int,
                                     end_time_ns: int) -> (dict, dict):
        """Extract the DxgKrnl profile ranges, totalled per function name.

        Considers ETW events from the 'Microsoft-Windows-DxgKrnl' provider with the
        'Profiler' task:

        - event 105 / opcode 1 starts a range;
        - event 106 / opcode 2 stops it.

        Ranges are paired per (function, tid), then merged per function across threads.
        A stop with no recorded start is assumed to have begun at ``start_time_ns``.
        A start with no stop is closed at ``end_time_ns``.

        :param target_pid: Only keep events from this process; ``None`` keeps all.
        :param target_tid: Only keep events from this thread; ``None`` keeps all.
        :param start_time_ns: Inclusive lower bound on event timestamp (ns).
        :param end_time_ns: Exclusive upper bound on event timestamp (ns).
        :return: ``(durations_by_function, counts_by_function)``, or ``(None, None)``
                 when any required ETW table is missing.
        """
        if not self.__supports_table('ETW_EVENTS') or not self.__supports_table('ETW_TASKS') or not self.__supports_table('ETW_PROVIDERS'):
            if not self.quiet:
                logger.warning("ETW_EVENTS/ETW_TASKS not supported")
            # Match the (dict, dict) shape of the success path so callers can always unpack.
            return None, None

        # We are looking for the Profiler task names. There might be other interrupts elsewhere, but this will do for now.
        profileTaskID = -1
        for row in self.__get_connection().execute('select * from ETW_TASKS where taskNameId=?', (self.get_string_id('Profiler'),)):
            profileTaskID = row['taskId']

        dxgKrnlProviderID = -1
        for row in self.__get_connection().execute('select * from ETW_PROVIDERS where providerNameId=?', (self.get_string_id('Microsoft-Windows-DxgKrnl'),)):
            dxgKrnlProviderID = row['providerId']

        query = 'SELECT * FROM ' + self.__get_etw_events_sql_table_for_query()
        query += ' WHERE timestamp >= ? and timestamp < ? and '
        query += self.__etw_task_id_column_name() + ' = ? and '
        query += self.__etw_provider_id_column_name() + ' = ? '
        query += 'ORDER BY timestamp ASC'

        # (function_name, tid) -> list of Frames; ranges are thread specific until merged below.
        regions_dict = {}
        for row in self.__get_connection().execute(query, (str(start_time_ns), str(end_time_ns), str(profileTaskID), str(dxgKrnlProviderID))):
            pid, tid = self.__get_pid_tid_from_row(row)

            if target_pid is not None and pid != target_pid:
                continue

            if target_tid is not None and tid != target_tid:
                continue

            event_id = row[self.__etw_event_id_column_name()]
            opcode = row['opcode']
            timestamp = row['timestamp']
            data_dict = json.loads(row['data'])
            # Skip events without a function name instead of raising KeyError
            # (__get_etw_events also treats 'Function' as optional).
            function_name = data_dict.get('Function')
            if not function_name:
                continue

            key = (function_name, tid)
            if event_id == 105 and opcode == 1:  # Start event
                regions_dict.setdefault(key, []).append(Frames(start=timestamp, end=0))
            elif event_id == 106 and opcode == 2:  # Stop event
                regions = regions_dict.get(key)
                if regions is None:
                    # No start seen: assume the range was open at the start of the time period.
                    regions_dict[key] = [Frames(start=start_time_ns, end=timestamp)]
                elif regions[-1].end == 0:
                    regions[-1].end = timestamp
                else:
                    logger.warning(f"Profile region already has an end. Ignoring {function_name} for tid {tid}")

        # Close ranges that never saw a stop event at the end of the time period.
        for regions in regions_dict.values():
            for region in regions:
                if region.end == 0:
                    region.end = end_time_ns

        # Merge per-thread ranges by function name into fresh lists.
        collapsed_regions_dict = {}
        for (region_function_name, _region_tid), regions in regions_dict.items():
            collapsed_regions_dict.setdefault(region_function_name, []).extend(regions)

        region_durations_dict = {}
        region_counts_dict = {}
        for region_function_name, regions in collapsed_regions_dict.items():
            region_counts_dict[region_function_name] = len(regions)
            region_durations_dict[region_function_name] = sum(region.end - region.start for region in regions)

        return region_durations_dict, region_counts_dict

    ####################################################
    # Get hot ETW events for the given time period and tid/pid
    ####################################################
    def __get_etw_events(self,
                         target_pid: int,
                         target_tid: int,
                         start_time_ns: int,
                         end_time_ns: int) -> (dict, dict):
        """Count ETW events per task in a time window.

        Keys of the sample dict are usually the ETW task id. Two exceptions apply:

        - 'BlockThread' events with a ``Reason`` in their data are keyed as
          ``'BlockThread:<reason>'``;
        - 'Profiler' events are keyed by their ``Function`` name when
          ``deconstruct_profiler_task`` is enabled (it is off by default here).

        :param target_pid: Only keep events from this process; ``None`` keeps all.
        :param target_tid: Only keep events from this thread; ``None`` keeps all.
        :param start_time_ns: Inclusive lower bound on event timestamp (ns).
        :param end_time_ns: Exclusive upper bound on event timestamp (ns).
        :return: ``(sample_counts, task_names_by_id)``, or ``(None, None)`` when either
                 ETW_EVENTS or ETW_TASKS is missing.
        """
        deconstruct_profiler_task = False
        deconstruct_block_thread_task = True

        # Both tables are queried below, so bail out if either one is missing.
        if not self.__supports_table('ETW_EVENTS') or not self.__supports_table('ETW_TASKS'):
            if not self.quiet:
                logger.warning("ETW_EVENTS/ETW_TASKS not supported")
            return None, None

        profileTaskID = -1
        for row in self.__get_connection().execute('select * from ETW_TASKS where taskNameId=?', (self.get_string_id('Profiler'),)):
            profileTaskID = row['taskId']

        blockThreadTaskID = -1
        for row in self.__get_connection().execute('select * from ETW_TASKS where taskNameId=?', (self.get_string_id('BlockThread'),)):
            blockThreadTaskID = row['taskId']

        # First name seen for each task id wins.
        etw_task_names = {}
        for row in self.__get_connection().execute('SELECT * FROM ETW_TASKS'):
            if row['taskId'] not in etw_task_names:
                etw_task_names[row['taskId']] = self.get_string(row['taskNameId'])

        query = 'SELECT * FROM ' + self.__get_etw_events_sql_table_for_query()
        query += ' WHERE timestamp >= ? and timestamp < ?'
        query += ' ORDER BY timestamp ASC'

        etw_sample_dict = {}
        for row in self.__get_connection().execute(query, (str(start_time_ns), str(end_time_ns))):
            pid, tid = self.__get_pid_tid_from_row(row)

            if target_pid is not None and pid != target_pid:
                continue

            if target_tid is not None and tid != target_tid:
                continue

            name_id = row[self.__etw_task_id_column_name()]
            # If a profiler task, try to extract the function name to use as the task name
            if deconstruct_profiler_task and name_id == profileTaskID:
                data_dict = json.loads(row['data'])
                if 'Function' in data_dict:
                    name_id = data_dict['Function']

            # If a block thread task, try to extract the reason to use as the task name.
            # The f-string also tolerates non-string reasons, which '+' would reject.
            if deconstruct_block_thread_task and name_id == blockThreadTaskID:
                data_dict = json.loads(row['data'])
                if 'Reason' in data_dict:
                    name_id = f"BlockThread:{data_dict['Reason']}"

            etw_sample_dict[name_id] = etw_sample_dict.get(name_id, 0) + 1

        # Slight abuse of the return type here - return the task names as a dict
        return etw_sample_dict, etw_task_names  # TODO - don't return the names, just encode them as keys

    ####################################################
    # Get total GPU frame time
    ####################################################
    # Unused
#    def __get_total_gpu_duration(self,
#                                start: Optional[float] = None,
#                                end: Optional[float] = None,
#                                target_pid: Optional[int] = None) -> float:
#        gpu_timeslice_list = []
#        if not self.__supports_table('DX12_WORKLOAD'):
#            return 0
#
#        total_gpu_time = 0
#        for row in self.__get_connection().execute('''SELECT start, end, globalTid FROM DX12_WORKLOAD ORDER BY start ASC'''):
#            pid, tid = tu.convert_global_tid(row['globalTid'])
#            if (target_pid is not None) and (target_pid != pid):
#                continue
#            # Trim
#            if start is not None and row['start'] < start:
#                continue
#            if end is not None and row['end'] > end:
#                continue
#
#            gpu_timeslice_list.append(Frames(row['start'], row['end']))
#
#        # Use a dequeue to track all asynchronous usage of GPU engines
#        # to get an accurate utilisation number
#        d = deque()
#        prev = None
#        for t in gpu_timeslice_list:
#            # remove a timeslice if it is no longer relevant
#            gc = []
#            for e in d:
#                if e.end < t.start:
#                    gc.append(e)
#
#            # sort on ascending end points
#            gc.sort(key=sort_on_end)
#            for g in gc:
#                time = g.end - prev
#                total_gpu_time += time
#                prev = g.end
#                d.remove(g)
#
#            # skip if this is the only thread
#            count = len(d)
#            if count != 0:
#                time = t.start - prev
#                total_gpu_time += time
#                prev = t.start
#            else:
#                prev = t.start
#
#            # put the current timeslice on the deque
#            d.append(t)
#
#        return total_gpu_time
