/**
 * Copyright (c) 2023-2025 itemis AG - All rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited
 */
package com.yakindu.sctunit.coverage.calculation

import com.yakindu.sct.model.sgraph.Statechart
import com.yakindu.sct.model.sgraph.Transition
import com.yakindu.sct.model.sgraph.Vertex
import com.yakindu.sct.simulation.core.coverage.Measurement
import com.yakindu.sct.simulation.core.coverage.Measurement.StateTransitionCoverage
import com.yakindu.sct.simulation.core.coverage.Measurement.TestCaseCount
import com.yakindu.sct.simulation.core.coverage.MeasurementExtension
import com.yakindu.sct.simulation.core.coverage.StatechartMeasurementBuilder
import com.yakindu.sctunit.sCTUnit.SCTUnitClass
import com.yakindu.sctunit.sCTUnit.SCTUnitSuite
import java.util.List

import static extension com.yakindu.sct.model.sgraph.util.SubchartDFS.*

/**
 * Calculates the measures of an SCTUnit coverage measurement tree: the number of
 * test cases ({@link TestCaseCount}) and the state/transition coverage
 * ({@link StateTransitionCoverage}) of suites, test classes and statecharts.
 * 
 * @author finlay weegen - Initial contribution and API
 *
 */
class SCTUnitMeasureCalculator {
	
	protected extension MeasurementExtension = new MeasurementExtension
	protected extension StatechartMeasurementBuilder = new StatechartMeasurementBuilder
	
	/** Name of the child measurement that holds the suite-wide aggregated statechart measurements. */
	static val SUITE_MEASUREMENTS = "test suite coverage" 
	
	
	// ------- calculating all measures
	
	/**
	 * Defines both the test case count and the state/transition coverage of the given measurement.
	 */
	def void defineMeasures(Measurement it) {
		defineTestCaseCount
		defineCoverage	
	}
	
	
	// ------- calculating test case count 
	
	/**
	 * Defines the test case count of the given measurement by dispatching on its subject.
	 * 
	 * @return the (possibly updated) test case count of the measurement
	 */
	def TestCaseCount defineTestCaseCount(Measurement it) {
		defineTestCaseCount(it.subject)
	}
	
	/**
	 * Sums the test case counts of all measurements found in the tree whose subject is an
	 * {@link SCTUnitClass} and replaces any existing {@link TestCaseCount} of this measurement.
	 * Class measurements without a test case count contribute 0; a suite without test classes yields 0.
	 */
	def protected dispatch TestCaseCount defineTestCaseCount(Measurement it, SCTUnitSuite subject) {
		// fold with seed 0 rather than reduce: reduce returns null for an empty input,
		// which would hand null to the TestCaseCount constructor for a suite without classes.
		val total = it.findAll[ m | m.subject instanceof SCTUnitClass ]
						.map[ m | if (m.testCaseCount !== null) m.testCaseCount.testCases else 0 ]
						.fold(0)[ acc, n | acc + n ]
		
		val result = new TestCaseCount(total)
		measures.removeIf[ measure | measure instanceof TestCaseCount ]
		measures += result
		return result
	}
	
	/**
	 * Any other subject keeps the test case count already stored on the measurement.
	 */
	def protected dispatch TestCaseCount defineTestCaseCount(Measurement it, Object subject) {
		
		return testCaseCount
	}
	
	/**
	 * Measurements without a subject (e.g. the aggregated suite child created by
	 * {@code aggregateStatechartMeasurements}) keep their current test case count.
	 * Without this case the dispatcher would reject a null subject with an IllegalArgumentException.
	 */
	def protected dispatch TestCaseCount defineTestCaseCount(Measurement it, Void subject) {
		return testCaseCount
	}
	
	// ------- calculating state transition coverage
	
	/**
	 * Defines the state/transition coverage of the given measurement by dispatching on its subject.
	 * 
	 * @return the coverage computed for the measurement
	 */
	def StateTransitionCoverage defineCoverage(Measurement it) {
		it.defineCoverage(it.subject)
	}

	/**
	 * Builds a fresh aggregated child measurement from all statechart measurements of the suite,
	 * computes its coverage and adds that coverage to the suite measurement. The aggregated child
	 * is afterwards tagged with the suite as subject and type.
	 */
	def protected dispatch StateTransitionCoverage defineCoverage(Measurement it, SCTUnitSuite subject) {
		val aggregatedMeasure = aggregateStatechartMeasurements

		// the aggregated child has no subject yet, so this resolves to the Void case
		val coverage = aggregatedMeasure.defineCoverage + StateTransitionCoverage.ZERO
		aggregatedMeasure.subject = subject
		aggregatedMeasure.type = SCTUnitSuite
		measures += coverage
					
		return coverage
	}

	/**
	 * Computes the coverage of the class's statechart, then copies the visits recorded inside that
	 * statechart measurement onto the matching subchart measurements among this measurement's children
	 * and computes their coverage. Finally combines the coverage of all children.
	 */
	def protected dispatch StateTransitionCoverage defineCoverage(Measurement it, SCTUnitClass subject) {
		
		val statechartMeasurement = it.forSubject(subject.statechart)
		statechartMeasurement.defineCoverage
		
		val subcharts = subject.statechart.subcharts.map[p|p.statechart].toSet
		
		for (sub : subcharts) {
			val subMeasurement = it.children.findFirst[ c | c.subject === sub ]
			// a subchart without its own child measurement has nowhere to receive visits
			if (subMeasurement !== null) {
				subMeasurement.addVisitsOfSubjectFrom(statechartMeasurement)
				subMeasurement.defineCoverage
			}
		}
		
		combineAndUpdateChildMeasures(it, StateTransitionCoverage.ZERO)
	}

	/**
	 * Defines the coverage of nested statechart measurements first, then combines only the
	 * non-statechart children (states, transitions, ...) into this statechart's coverage.
	 */
	def protected dispatch StateTransitionCoverage defineCoverage(Measurement it, Statechart statechart) {

		children.filter[subject instanceof Statechart].forEach[
			defineCoverage
		]

		combineAndUpdateChildMeasures(
			StateTransitionCoverage.ZERO,
			children.filter[!(subject instanceof Statechart)].toList
		)
		
	}
	
	/** Any other subject: coverage is the combination of all children. */
	def protected dispatch StateTransitionCoverage defineCoverage(Measurement it, Object subject) {
		combineAndUpdateChildMeasures(it, StateTransitionCoverage.ZERO)
	}
	
	/** No subject: coverage is the combination of all children. */
	def protected dispatch StateTransitionCoverage defineCoverage(Measurement it, Void v) {
		combineAndUpdateChildMeasures(it, StateTransitionCoverage.ZERO)
	}

	/** A vertex is fully covered once it has at least one visit. */
	def protected dispatch StateTransitionCoverage defineCoverage(Measurement it, Vertex vertex) {
		combineAndUpdateChildMeasures(it, if(it.visits.count > 0) StateTransitionCoverage.FULLY_COVERED else StateTransitionCoverage.NOT_COVERED)
	}

	/** A transition is fully covered once it has at least one visit. */
	def protected dispatch StateTransitionCoverage defineCoverage(Measurement it, Transition transition) {
		combineAndUpdateChildMeasures(it, if(it.visits.count > 0) StateTransitionCoverage.FULLY_COVERED else StateTransitionCoverage.NOT_COVERED)
	}	

	/**
	 * Combines {@code overallMeasurement} with the coverage of all direct children.
	 */
	protected def StateTransitionCoverage combineAndUpdateChildMeasures(Measurement it, StateTransitionCoverage overallMeasurement) {
		it.combineAndUpdateChildMeasures(overallMeasurement, it.children)
	}
	
	/**
	 * Starts from {@code overallMeasurement}, adds the coverage defined for each of
	 * {@code childMeasurements}, appends the result to this measurement's measures and returns it.
	 */
	protected def StateTransitionCoverage combineAndUpdateChildMeasures(Measurement it, StateTransitionCoverage overallMeasurement, List<Measurement> childMeasurements) {
		var result = overallMeasurement
		for (child : childMeasurements) {
			result += child.defineCoverage
		}
		measures += result
		return result
	}
	
	
	// ------- aggregate visits
	
	/**
	 * Replaces the aggregated suite child (named "test suite coverage") of this measurement.
	 * For every distinct statechart found in the tree, a new measurement is built and
	 * receives the summed visits of all measurements for that statechart. The new child is
	 * inserted as the first child and returned; its subject is left unset.
	 */
	def protected Measurement aggregateStatechartMeasurements(Measurement it) {
	
		// must happen before the findAll below so a previous aggregation is not counted again
		children.removeIf[ child | SUITE_MEASUREMENTS == child.name ]
		
		val statecharts = it.findAll[ m | m.subject instanceof Statechart ]
							.map[ m | m.subject as Statechart ]
							.toSet
							
		val aggregatedMeasurements = new Measurement => [ m |
			m.name = SUITE_MEASUREMENTS
		]
		
		val root = it
		aggregatedMeasurements.children.addAll(
			statecharts.map[ sc |
				val aggregated = sc.buildMeasurement
				addVisitsOfSubject(aggregated, root, sc)
				aggregated
			])
			
		children.add(0, aggregatedMeasurements)
		
		aggregatedMeasurements			
	}
	
	/**
	 * Adds to {@code to} the visits recorded in every measurement within {@code from}
	 * whose subject is the same object as {@code to}'s subject.
	 */
	def protected void addVisitsOfSubjectFrom(Measurement to, Measurement from) {
		addVisitsOfSubject(to, from, to.subject)
	}
	
	/**
	 * Finds all measurements within {@code from} whose subject is identical to {@code subject} and,
	 * for each nested measurement with a non-zero visit count, adds that count to the measurement
	 * {@code to.forSubject(...)} returns for the same subject.
	 */
	def private void addVisitsOfSubject(Measurement to, Measurement from, Object subject) {
		from.findAll[ m | m.subject === subject ]
			.forEach[ scm |
				scm.findAll[ v | v.visits.count != 0 ]
					.forEach[ m |
						to.forSubject(m.subject).visits.count += m.visits.count
					]
			]
	}

}