/**
 * Copyright (c) 2023-2025 itemis AG - All rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited
 * 
 * Contributors:
 * 	Axel Terfloth - itemis AG
 *  Laszlo Kovacs - itemis AG
 * 
 */
package com.itemis.create.sctunit.transformation

import com.google.inject.Inject
import com.itemis.create.base.generator.core.Transformation
import com.itemis.create.base.generator.core.concepts.TimedInterface
import com.itemis.create.base.generator.core.concepts.VirtualTimer
import com.itemis.create.sunit.SUnitBuilder
import com.itemis.create.sunit.SUnitLib
import com.yakindu.base.base.NamedElement
import com.yakindu.base.expressions.ExpressionBuilder
import com.yakindu.base.expressions.expressions.ArgumentExpression
import com.yakindu.base.expressions.expressions.ElementReferenceExpression
import com.yakindu.base.expressions.expressions.ExpressionsFactory
import com.yakindu.base.expressions.expressions.IntLiteral
import com.yakindu.base.expressions.expressions.Literal
import com.yakindu.base.expressions.expressions.MultiplicativeOperator
import com.yakindu.base.expressions.expressions.PrimitiveValueExpression
import com.yakindu.base.types.ComplexType
import com.yakindu.base.types.Declaration
import com.yakindu.base.types.EnumerationType
import com.yakindu.base.types.Expression
import com.yakindu.base.types.Operation
import com.yakindu.base.types.Package
import com.yakindu.base.types.Part
import com.yakindu.base.types.Property
import com.yakindu.base.types.TypeBuilder
import com.yakindu.base.types.TypedDeclaration
import com.yakindu.sct.generator.core.codemodel.StatemachineClass
import com.yakindu.sct.model.sgraph.Statechart
import com.yakindu.sct.model.stext.concepts.StatechartAnnotations
import com.yakindu.sct.model.stext.stext.ActiveStateReferenceExpression
import com.yakindu.sct.model.stext.stext.VariableDefinition
import com.yakindu.sct.types.resource.Statechart2TypeTransformation
import com.yakindu.sctunit.sCTUnit.AssertionStatement
import com.yakindu.sctunit.sCTUnit.CodeBlock
import com.yakindu.sctunit.sCTUnit.EnterExpression
import com.yakindu.sctunit.sCTUnit.ExitExpression
import com.yakindu.sctunit.sCTUnit.ExpressionStatement
import com.yakindu.sctunit.sCTUnit.IfStatement
import com.yakindu.sctunit.sCTUnit.LoopStatement
import com.yakindu.sctunit.sCTUnit.MockReturnStatement
import com.yakindu.sctunit.sCTUnit.MockingStatement
import com.yakindu.sctunit.sCTUnit.ProceedExpression
import com.yakindu.sctunit.sCTUnit.ProceedUnit
import com.yakindu.sctunit.sCTUnit.ReturnStatement
import com.yakindu.sctunit.sCTUnit.SCTUnitClass
import com.yakindu.sctunit.sCTUnit.SCTUnitOperation
import com.yakindu.sctunit.sCTUnit.StatechartActiveExpression
import com.yakindu.sctunit.sCTUnit.StatechartFinalExpression
import com.yakindu.sctunit.sCTUnit.TestPackage
import com.yakindu.sctunit.sCTUnit.TriggerWithoutEventExpression
import com.yakindu.sctunit.sCTUnit.VariableDefinitionStatement
import com.yakindu.sctunit.sCTUnit.VerifyCalledStatement
import java.util.LinkedList
import org.eclipse.emf.ecore.EObject
import org.eclipse.emf.ecore.util.EcoreUtil
import org.eclipse.xtext.EcoreUtil2

import static com.yakindu.sct.generator.core.codemodel.StatemachineClass.STATEMACHINE_CLASS_TYPE

/**
 * This class transforms the SCTUnit language to sunit which is based on types & expressions
 */
class SCTUnit2SUnitTransformation extends Transformation<TestPackage, Package> {
	
	@Inject protected extension Statechart2TypeTransformation
	
	@Inject protected extension TypeBuilder
	@Inject protected extension ExpressionBuilder
	@Inject protected extension SUnitBuilder
	@Inject protected extension SUnitLib
	@Inject protected extension VirtualTimer
	@Inject protected extension TimedInterface
	@Inject protected extension StatemachineClass
	//TODO: Should not be a dependency as it is dependent on SGraph
	@Inject protected extension StatechartAnnotations
	
	protected extension ExpressionsFactory expFactory = ExpressionsFactory.eINSTANCE

	
	/**
	 * Transformation entry point: runs the {@code toSUnit} dispatch on the
	 * given test package and narrows the resulting declaration to a
	 * {@link Package}.
	 */
	override protected toTarget(TestPackage it) {
		val sunitRoot = toSUnit(it)
		sunitRoot as Package
	}
	
	/**
	 * No-op: this transformation applies no changes to the target package in
	 * this hook.
	 */
	override protected modifyTarget(Package target) {
	}

	/**
	 * Transforms a complete SCTUnit test package.
	 * <p>
	 * A new root package is created that receives the sunit package definition
	 * and a test package named like the source package. Every member of the
	 * source package is transformed and added to that test package.
	 *
	 * @param sctunitPackage the SCTUnit package to transform
	 * @return the newly created root package
	 */
	def dispatch Declaration toSUnit(TestPackage sctunitPackage) {
		val root = _package

		// Publish the root as target right away so that it is available
		// within the subsequent transformation steps.
		target = root

		root.member += defineSunitPackage

		val sunitPackage = _sunit(sctunitPackage.name)
		root.member += sunitPackage

		for (member : sctunitPackage.member) {
			sunitPackage.member += member.toSUnit
			sunitPackage.addStatechartNamespaceAsImport(member)
		}
		sunitPackage.transformedRootFrom(sctunitPackage)

		root
	}
	
	/**
	 * Adds the namespaces of all statecharts used by the complex types of the
	 * given sunit package as imports of that package. A namespace is only
	 * added when it is set and not yet contained in the package imports.
	 *
	 * @param sunitPackage the package that is inspected and receives the imports
	 * @param testClass the source test class (only used for dispatching)
	 */
	def dispatch addStatechartNamespaceAsImport(Package sunitPackage, SCTUnitClass testClass){
		for (type : sunitPackage.member.filter(ComplexType)) {
			for (sctType : type.features.filter[statechartType]) {
				val statechartNamespace = (sctType.originSource as Statechart).namespace
				if (statechartNamespace !== null && !sunitPackage.imports.contains(statechartNamespace))
					sunitPackage._import(statechartNamespace)
			}
		}
	}
	
	/**
	 * Dispatch fallback for package members that are not test classes:
	 * nothing is imported.
	 */
	def dispatch addStatechartNamespaceAsImport(Package SUnitPackage, EObject it){}
	
	/**
	 * Transforms an SCTUnit test class into an sunit test class.
	 * <p>
	 * The test class receives a copy of the type of the statechart under test,
	 * a {@code statemachine} part, the transformed variable definitions and
	 * features, and finally all SCTUnit specific expressions inside the class
	 * are substituted by expressions on the state machine or the virtual timer.
	 * The order of the steps matters: the substitutions operate on the
	 * contents that were added to the test class before.
	 *
	 * @param sctunitClass the SCTUnit test class to transform
	 * @return the created test class
	 */
	def dispatch Declaration toSUnit(SCTUnitClass sctunitClass) {
		
		_testClass(sctunitClass.name) => [
			
			// NOTE(review): presumably makes sure the type model of the statechart exists
			// before it is looked up below - confirm in Statechart2TypeTransformation
			sctunitClass.statechart.createTypeDescription
			val statemachineType = sctunitClass.statechart.statechartType
			
			// the test class owns an annotated copy of the statechart type
			features += statemachineType.createCopy => [
				_annotate(STATEMACHINE_CLASS_TYPE)
			]
			
			// adds the 'statemachine' part plus types and parts for nested statechart types
			val statemachine = defineStatemachineVar(statemachineType)
			
			// puts the state enum into the statechart type and names it "State"
			moveStateEnumType(sctunitClass.statechart)
			
			
			sctunitClass.variableDefinitions.forEach[ f | it.features += f.definition.toSUnit]
						
			sctunitClass.features.forEach[ f | it.features += f.toSUnit]
			
			// carry over the cycle based annotation of the statechart
			if(sctunitClass.statechart.isCycleBased)
				it.annotations += sctunitClass.statechart.getAnnotationOfType(StatechartAnnotations.CYCLE_BASED_ANNOTATION).copy
			
			substituteEnterExpressions(statemachine)
			substituteExitExpressions(statemachine)
			substituteTriggerWithoutEventExpressions(statemachine)
			
			// NOTE(review): helpers of the TimedInterface / VirtualTimer concepts;
			// their effect on the test class is not visible in this file
			createTimedInterface
			createVirtualTimer
			
			// proceed expressions are mapped to calls on the virtual timer variable
			substituteProceedExpressions(createVirtualTimerVar)
			
			substituteStatechartActiveExpressions(statemachine)
			substituteStatechartFinalExpressions(statemachine)
			substituteActiveStateReferenceExpressions(statemachine)
			substituteStatechartMemberCalls(statemachine)
			
			statemachineType.transformedRootFrom(sctunitClass.statechart)
		]
	}
	
	/**
	 * Transforms an SCTUnit operation into a public sunit operation.
	 * <p>
	 * Operations annotated with {@code Test} are marked as tests, operations
	 * annotated with {@code Ignore} are marked as ignored. Type specifier,
	 * parameters and body are taken over from the source operation.
	 *
	 * @param sctunitOperation the SCTUnit operation to transform
	 * @return the created operation
	 */
	def dispatch Declaration toSUnit(SCTUnitOperation sctunitOperation) {
		val operation = _public(_op(sctunitOperation.name))

		if (sctunitOperation.isTest) operation._test
		if (sctunitOperation.ignored) operation._ignore

		// NOTE(review): unlike variable definitions, the type specifiers are
		// handed over without copying - confirm that this is intended
		operation.typeSpecifier = sctunitOperation.typeSpecifier
		for (parameter : sctunitOperation.parameters) {
			operation._param(parameter.name, parameter.typeSpecifier)
		}

		operation.implementation = sctunitOperation.body.toExpression

		operation.transformedFrom(sctunitOperation)
		operation
	}
	
	/**
	 * Transforms a variable definition into an sunit variable with a copied
	 * type specifier and a copied initial value.
	 *
	 * @param varDef the variable definition to transform
	 * @return the created variable
	 */
	def dispatch Declaration toSUnit(VariableDefinition varDef) {
		val variable = _variable(varDef.name, varDef.typeSpecifier.copy)
		variable.initialValue = varDef.initialValue.copy
		variable.transformedFrom(varDef)
		variable
	}
	
	
	/**
	 * Maps an assertion statement to a call of the sunit assert operation.
	 * The asserted expression is copied; the error message is passed as string
	 * argument, or as {@code null} when the statement defines none.
	 */
	def dispatch Expression toExpression(AssertionStatement ts) {
		val message = if (ts.errorMsg !== null) ts.errorMsg._string
		_call(target.sunitAssert)._with(ts.expression.copy, message)
	}
	
	/**
	 * Maps a mocking statement to a call of the sunit mock operation that is
	 * applied to the mocked element and its argument values. For a
	 * {@link MockReturnStatement} with a value, a call of the sunit return
	 * operation with that value is chained.
	 */
	def dispatch Expression toExpression(MockingStatement ts) {
		val mockedCall = ts.reference as ElementReferenceExpression
		val argValues = mockedCall.arguments.map[value]
		val mockExpr = _call(target.sunitMock)._with(mockedCall.reference._ref._with(argValues))

		switch ts {
			MockReturnStatement case ts.value !== null:
				return mockExpr._call(target.sunitReturn)._with(ts.value)
			default:
				return mockExpr
		}
	}
	
	/**
	 * Maps a verify-called statement to a call of the sunit called operation.
	 * If an explicit count is given, a times call with that count is chained;
	 * a negated statement without count chains a times call with 0.
	 */
	def dispatch Expression toExpression(VerifyCalledStatement ts) {
		val verifiedCall = ts.reference as ArgumentExpression
		val argValues = verifiedCall.arguments.map[value]
		val calledExpr = _call(target.sunitCalled)._with(verifiedCall._with(argValues))

		if (ts.times !== null)
			calledExpr._call(target.sunitTimes)._with(ts.times.value._integer)
		else if (ts.negated)
			calledExpr._call(target.sunitTimes)._with(0._integer)
		else
			calledExpr
	}

	/**
	 * Maps a code block to a block expression that contains the transformed
	 * statements in their original order.
	 */
	def dispatch Expression toExpression(CodeBlock ts) {
		val block = _block
		for (statement : ts.code) {
			block.expressions += statement.toExpression
		}
		block
	}

	/**
	 * Dispatch case for an absent statement: yields no expression.
	 */
	def dispatch Expression toExpression(Void ts) {
		return null
	}


	/**
	 * Maps an expression statement to a copy of its expression.
	 */
	def dispatch Expression toExpression(ExpressionStatement ts) {
		val wrapped = ts.expression
		wrapped.copy
	}
	
	/**
	 * Maps an if statement to an if expression with a copied condition and
	 * transformed then / else branches.
	 */
	def dispatch Expression toExpression(IfStatement ts) {
		val condition = ts.condition.copy
		val thenBranch = ts.then.toExpression
		val elseBranch = ts.^else.toExpression
		_if(condition, thenBranch, elseBranch)
	}

	/**
	 * Maps a loop statement to a while expression with a copied guard and a
	 * transformed body.
	 */
	def dispatch Expression toExpression(LoopStatement ts) {
		val guard = ts.guard.copy
		val body = ts.body.toExpression
		_while(guard, body)
	}
	
	/**
	 * Maps a return statement to a return expression with a copy of the
	 * returned value.
	 */
	def dispatch Expression toExpression(ReturnStatement ts) {
		val returned = ts.returnValue.copy
		_return(returned)
	}
	
	/**
	 * Maps a local variable definition statement to a declaration expression
	 * for the transformed variable definition.
	 */
	def dispatch Expression toExpression(VariableDefinitionStatement ts) {
		val declaration = ts.definition.toSUnit
		_declare(declaration)
	}
	
	
	/**
	 * Adds a protected part named {@code statemachine} of the given statechart
	 * type to the test class and afterwards defines the types and parts for
	 * the statechart types nested in that statechart type.
	 *
	 * @param testClass the test class that receives the features
	 * @param statechartType the type of the statechart under test
	 * @return the created {@code statemachine} part
	 */
	def defineStatemachineVar(ComplexType testClass, ComplexType statechartType ) {
		val stmPart = _protected(_part("statemachine", statechartType))
		testClass.features += stmPart

		testClass.defineChildStatemachineTypes(statechartType)
		testClass.defineChildStmParts(statechartType, stmPart, new LinkedList<String>, statechartType.name)

		stmPart
	}
	
	/**
	 * Creates a protected part in the test class for every statechart-typed
	 * declaration found anywhere below the given statechart type and recurses
	 * into the types of these declarations.
	 * <p>
	 * Declarations whose type is named like the root type are skipped. The
	 * names of the types on the current recursion path are kept in
	 * {@code typePath}; a type that is already on the path still gets its part
	 * but is not descended into again.
	 *
	 * @param testClass the test class that receives the parts
	 * @param statechartType the statechart type whose contents are inspected
	 * @param parent the part that represents {@code statechartType}
	 * @param typePath names of the types on the current recursion path
	 * @param rootTypeName name of the type of the statechart under test
	 */
	def void defineChildStmParts(ComplexType testClass, ComplexType statechartType, Part parent, LinkedList<String> typePath, String rootTypeName) {
		typePath.addLast(statechartType.name)

		// all (not only direct) statechart-typed declarations, except those of the root type
		val nestedStatecharts = statechartType.eAllContents
			.filter(TypedDeclaration)
			.filter[ decl | decl.type instanceof ComplexType && (decl.type as ComplexType).isStatechartType ]
			.filter[ decl | (decl.type as ComplexType).name != rootTypeName ]
			.toList

		for (decl : nestedStatecharts) {
			val nestedType = decl.type as ComplexType

			val nestedPart = _protected(_part(parent.name + "_" + decl.name + "_" + nestedType.name, nestedType) => [
				_annotate(Statechart2TypeTransformation.CHILD_STATECHART_ANNOTATION)
				traceOrigin(parent)
				traceOrigin(decl)
			])
			testClass.features += nestedPart

			// descend only into types that are not yet on the current path
			val alreadyOnPath = typePath.contains(nestedType.name)
			if (!alreadyOnPath) {
				defineChildStmParts(testClass, nestedType, nestedPart, typePath, rootTypeName)
			}
		}

		typePath.removeLast()
	}
	
	/**
	 * Adds copies of the statechart types that are used by declarations below
	 * the given statechart type to the test class and recurses into these
	 * types.
	 * <p>
	 * A type is only added if the test class contains neither a
	 * non-enumeration complex type whose origin source has the same name nor a
	 * feature whose origin is that type.
	 *
	 * @param testClass the test class that receives the type copies
	 * @param statechartType the statechart type whose contents are inspected
	 */
	def void defineChildStatemachineTypes(ComplexType testClass, ComplexType statechartType){
		
		statechartType.eAllContents.filter(TypedDeclaration).filter[f | f.type.isStatechartType ].forEach[td |
			// skip types that are already represented in the test class
			if(testClass.features.filter(ComplexType).filter[!(it instanceof EnumerationType) && (originSource as NamedElement).name == (td.type.originSource as NamedElement).name].nullOrEmpty && testClass.features.filter[it.origin === td.type].nullOrEmpty){
				val stmType = td.type.createCopy as ComplexType
				// first enumeration type of the package that contains the statechart type
				// NOTE(review): 'head' is null if that package has no enumeration type - confirm this cannot happen
				val enumTypeForStm = (td.type.eContainer as Package).member.filter(EnumerationType).head
				stmType.features += enumTypeForStm.createCopy
				testClass.features += stmType				
				testClass.defineChildStatemachineTypes(td.type as ComplexType)
			}				
		]
	}
	
	/**
	 * Replaces every {@link EnterExpression} below the given class by an
	 * expression that accesses the {@code enter} operation on the state machine.
	 *
	 * @param sunitClass the class whose contents are searched
	 * @param statemachine the state machine property
	 */
	def substituteEnterExpressions(ComplexType sunitClass, Property statemachine) {
		val enterExpressions = sunitClass.eAllContents.filter(EnterExpression)
		while (enterExpressions.hasNext) {
			val enterExpression = enterExpressions.next
			val enterCall = statemachine._ref._dot(statemachine.enterMethod)
			EcoreUtil.replace(enterExpression, enterCall)
		}
	}
	
	/**
	 * Replaces every {@link ExitExpression} below the given class by an
	 * expression that accesses the {@code exit} operation on the state machine.
	 *
	 * @param sunitClass the class whose contents are searched
	 * @param statemachine the state machine property
	 */
	def substituteExitExpressions(ComplexType sunitClass, Property statemachine) {
		val exitExpressions = sunitClass.eAllContents.filter(ExitExpression)
		while (exitExpressions.hasNext) {
			val exitExpression = exitExpressions.next
			val exitCall = statemachine._ref._dot(statemachine.exitMethod)
			EcoreUtil.replace(exitExpression, exitCall)
		}
	}
	
	/**
	 * Replaces every {@link StatechartFinalExpression} below the given class by
	 * an expression that accesses the {@code isFinal} operation on the state
	 * machine.
	 *
	 * @param sunitClass the class whose contents are searched
	 * @param statemachine the state machine property
	 */
	def substituteStatechartFinalExpressions(ComplexType sunitClass, Property statemachine) {
		val finalExpressions = sunitClass.eAllContents.filter(StatechartFinalExpression)
		while (finalExpressions.hasNext) {
			val finalExpression = finalExpressions.next
			val isFinalCall = statemachine._ref._dot(statemachine.isFinalMethod)
			EcoreUtil.replace(finalExpression, isFinalCall)
		}
	}
	
	/**
	 * Replaces every {@link StatechartActiveExpression} below the given class
	 * by an expression that accesses the {@code isActive} operation on the
	 * state machine.
	 *
	 * @param sunitClass the class whose contents are searched
	 * @param statemachine the state machine property
	 */
	def substituteStatechartActiveExpressions(ComplexType sunitClass, Property statemachine) {
		val activeExpressions = sunitClass.eAllContents.filter(StatechartActiveExpression)
		while (activeExpressions.hasNext) {
			val activeExpression = activeExpressions.next
			val isActiveCall = statemachine._ref._dot(statemachine.isActiveMethod)
			EcoreUtil.replace(activeExpression, isActiveCall)
		}
	}
	
	/**
	 * Replaces every {@link TriggerWithoutEventExpression} below the given
	 * class by an expression that accesses the {@code triggerWithoutEvent}
	 * operation on the state machine.
	 *
	 * @param sunitClass the class whose contents are searched
	 * @param statemachine the state machine property
	 */
	def substituteTriggerWithoutEventExpressions(ComplexType sunitClass, Property statemachine) {
		val triggerExpressions = sunitClass.eAllContents.filter(TriggerWithoutEventExpression)
		while (triggerExpressions.hasNext) {
			val triggerExpression = triggerExpressions.next
			val triggerCall = statemachine._ref._dot(statemachine.triggerWithoutEventMethod)
			EcoreUtil.replace(triggerExpression, triggerCall)
		}
	}
	
	/**
	 * Replaces every {@link ProceedExpression} below the given class.
	 * <p>
	 * If the timer is not of the virtual timer type, the expression is
	 * replaced by an empty block. Otherwise a proceed in cycles becomes a call
	 * of the timer's cycle leap operation with the unchanged value, and any
	 * other proceed becomes a call of the time leap operation with the value
	 * converted by {@code proceedValue}.
	 *
	 * @param sunitClass the class whose contents are searched
	 * @param timer the timer property
	 */
	def substituteProceedExpressions(ComplexType sunitClass, Property timer) {
		val proceedExpressions = sunitClass.eAllContents.filter(ProceedExpression)
		while (proceedExpressions.hasNext) {
			val proceed = proceedExpressions.next
			val replacement =
				if (timer.type !== virtualTimerType)
					_block
				else if (proceed.unit == ProceedUnit.CYCLES)
					timer._ref._dot(timer.type.cycleLeapOp)._with(proceed.value)
				else
					timer._ref._dot(timer.type.timeLeapOp)._with(proceed.value.proceedValue(proceed.unit))
			EcoreUtil.replace(proceed, replacement)
		}
	}
	
	/**
	 * Converts a non-literal proceed value of the given unit into an
	 * expression that yields milliseconds.
	 *
	 * @param it the expression that denotes the duration
	 * @param unit the time unit of the duration
	 * @return {@code it} for milliseconds, otherwise a scaling expression
	 */
	def dispatch Expression proceedValue(Expression it, ProceedUnit unit){
		scaleToMilliseconds(it, unit)
	}
	
	/**
	 * Converts a referenced proceed value of the given unit into an
	 * expression that yields milliseconds.
	 *
	 * @param it the element reference that denotes the duration
	 * @param unit the time unit of the duration
	 * @return {@code it} for milliseconds, otherwise a scaling expression
	 */
	def dispatch Expression proceedValue(ElementReferenceExpression it, ProceedUnit unit){
		scaleToMilliseconds(it, unit)
	}
	
	/**
	 * Wraps the duration into a multiply / divide expression that scales it to
	 * milliseconds: seconds are multiplied by 1000, microseconds are divided
	 * by 1000 and every other unit is treated as nanoseconds and divided by
	 * 1000000.
	 * <p>
	 * Sub-millisecond units must be divided. Multiplying by an integer factor
	 * computed as {@code 1 / 1000} would multiply the duration by 0.
	 */
	def private Expression scaleToMilliseconds(Expression duration, ProceedUnit unit) {
		if (unit == ProceedUnit.MILLISECOND)
			return duration
		
		val scaleUp = (unit == ProceedUnit.SECOND)
		val factor = 
			if (scaleUp || unit == ProceedUnit.MICROSECOND)
				1000
			else
				1000000
		
		return createNumericalMultiplyDivideExpression => [ nmde |
			nmde.leftOperand = duration
			nmde.operator = if (scaleUp) MultiplicativeOperator.MUL else MultiplicativeOperator.DIV
			nmde.rightOperand = createPrimitiveValueExpression => [ pv |
				pv.value = createIntLiteral => [ il | il.value = factor ]
			]
		]
	}
	
	/**
	 * Converts a literal proceed value of the given unit into milliseconds.
	 * For milliseconds the expression itself is returned; otherwise a new
	 * primitive value expression is created whose literal is computed by
	 * {@code calculateProceedValue}.
	 *
	 * @param it the primitive value expression that denotes the duration
	 * @param unit the time unit of the duration
	 */
	def dispatch Expression proceedValue(PrimitiveValueExpression it, ProceedUnit unit){
		if (unit != ProceedUnit.MILLISECOND) {
			val converted = createPrimitiveValueExpression
			converted.value = it.value.calculateProceedValue(unit)
			return converted
		}
		return it
	}
	
	/**
	 * Creates a new integer literal that holds the given literal's value
	 * converted to milliseconds: seconds are multiplied by 1000, microseconds
	 * are divided by 1000 and every other unit is treated as nanoseconds and
	 * divided by 1000000.
	 *
	 * @param it the literal that holds the duration
	 * @param unit the time unit of the duration
	 */
	def dispatch IntLiteral calculateProceedValue(IntLiteral it, ProceedUnit unit){
		val milliseconds =
			if (unit == ProceedUnit.SECOND) {
				value * 1000
			} else if (unit == ProceedUnit.MICROSECOND) {
				value / 1000
			} else {
				// nanoseconds
				value / 1000000
			}
		val result = createIntLiteral
		result.value = milliseconds
		result
	}
	
	/**
	 * Dispatch fallback for literals that are not integer literals: the empty
	 * body produces no literal.
	 */
	def dispatch IntLiteral calculateProceedValue(Literal it, ProceedUnit unit){}
	
	/**
	 * Replaces every {@link ActiveStateReferenceExpression} below the given
	 * class by a call of the state machine's {@code isStateActive} operation.
	 * The argument is a reference to the enumerator of the state enum type
	 * whose origin is the referenced state.
	 *
	 * @param sunitClass the class whose contents are searched
	 * @param statemachine the state machine property
	 */
	def substituteActiveStateReferenceExpressions(ComplexType sunitClass, Property statemachine) {
		val stateReferences = sunitClass.eAllContents.filter(ActiveStateReferenceExpression)
		while (stateReferences.hasNext) {
			val stateReference = stateReferences.next
			val stateActiveOp = statemachine.isStateActiveMethod
			val sourceState = stateReference.value
			val stateEnum = (statemachine.type.origin as Statechart).stateEnumType
			val targetState = stateEnum.enumerator.findFirst[ e | e.origin === sourceState ]

			EcoreUtil.replace(stateReference, statemachine._ref._call(stateActiveOp)._with(targetState._ref))
		}
	}	
	
	/**
	 * Replaces every element reference below the given class that refers to a
	 * statechart member by an access of that member on the state machine.
	 * References that are part of a mock or called expression are left alone.
	 *
	 * @param sunitClass the class whose contents are searched
	 * @param statemachine the state machine property
	 */
	def substituteStatechartMemberCalls(ComplexType sunitClass, Property statemachine) {
		val memberReferences = sunitClass.eAllContents
			.filter(ElementReferenceExpression)
			.filter[ candidate | candidate.reference.isStatechartMember && !candidate.mockedOperation ]
		while (memberReferences.hasNext) {
			val memberReference = memberReferences.next
			val memberAccess = statemachine._ref._dot(memberReference.reference)
			EcoreUtil.replace(memberReference, memberAccess)
		}
	}
	
	/**
	 * Checks whether the given reference is contained, at any depth, in an
	 * element reference to the sunit mock or the sunit called operation.
	 */
	def protected mockedOperation(ElementReferenceExpression it){
		val enclosingReferences = EcoreUtil2.getAllContainers(it).filter(ElementReferenceExpression)
		enclosingReferences.exists[ enclosing |
			enclosing.reference === target.sunitMock || enclosing.reference === target.sunitCalled
		]
	}
	
	
	/**
	 * Adds the state enum type of the statechart to the features of the
	 * statechart type and names it {@code State}.
	 */
	def void moveStateEnumType(Statechart sc) {
		val stateEnum = sc.stateEnumType
		sc.statechartType.features += stateEnum
		stateEnum.name = "State"
	}
	
	
	/**
	 * @return true if the operation carries an annotation whose type is named
	 *         {@code Test}
	 */
	def isTest(SCTUnitOperation it) {
		annotation.exists[ a | a.type.name == "Test" ]
	}
	
	/**
	 * @return true if the operation carries an annotation whose type is named
	 *         {@code Ignore}
	 */
	def ignored(SCTUnitOperation it) {
		annotation.exists[ a | a.type.name == "Ignore" ]
	}
	
	/**
	 * @return true if the root container of the element is a statechart
	 */
	def isStatechartMember(EObject it) {
		val root = EcoreUtil.getRootContainer(it)
		root instanceof Statechart
	}
	
	/**
	 * @return true if the root container of the element is a test package
	 */
	def isTestMember(EObject it) {
		val root = EcoreUtil.getRootContainer(it)
		root instanceof TestPackage
	}
	
	// TODO: use concept for enter method
	/**
	 * Looks up the member named {@code enter} of the given element and casts
	 * it to an operation.
	 */
	def Operation enterMethod(EObject it) {
		lookupMember(it, "enter") as Operation
	}	
	
	// TODO: use concept for exit method
	/**
	 * Looks up the member named {@code exit} of the given element and casts
	 * it to an operation.
	 */
	def Operation exitMethod(EObject it) {
		lookupMember(it, "exit") as Operation
	}	

	/**
	 * Looks up the member named {@code isFinal} of the given element and casts
	 * it to an operation.
	 */
	def Operation isFinalMethod(EObject it) {
		lookupMember(it, "isFinal") as Operation
	}	
	
	/**
	 * Looks up the member named {@code isActive} of the given element and
	 * casts it to an operation.
	 */
	def Operation isActiveMethod(EObject it) {
		lookupMember(it, "isActive") as Operation
	}	
	
	// TODO: use concept for TriggerWithoutEventMethod
	/**
	 * Looks up the member named {@code triggerWithoutEvent} of the given
	 * element and casts it to an operation.
	 */
	def Operation triggerWithoutEventMethod(EObject it) {
		lookupMember(it, "triggerWithoutEvent") as Operation
	}	
	
	// TODO: use concept for isStateActive method
	/**
	 * Looks up the member named {@code isStateActive} of the given element and
	 * casts it to an operation.
	 */
	def Operation isStateActiveMethod(EObject it) {
		lookupMember(it, "isStateActive") as Operation
	}	
	
	
	/**
	 * Dispatch fallback: elements without a dedicated lookup rule have no
	 * members.
	 */
	def dispatch Declaration lookupMember(EObject it, String name) {
		return null
	}	
	
	/**
	 * Looks up the member in the type of the property.
	 */
	def dispatch Declaration lookupMember(Property it, String name) {
		val propertyType = type
		propertyType.lookupMember(name)
	}	
	
	/**
	 * Looks up a declaration by name in the features of the complex type.
	 * If the type itself declares no such feature, the super types are
	 * searched in declaration order and the first match is returned.
	 *
	 * @param it the complex type to search
	 * @param name the name of the requested member
	 * @return the matching declaration or {@code null} if there is none
	 */
	def dispatch Declaration lookupMember(ComplexType it, String name) {
		val ownMember = features.filter(Declaration).findFirst[ declaration | declaration.name == name ]
		if (ownMember !== null) {
			return ownMember
		}

		for (superType : superTypes.map[ specifier | specifier.type ]) {
			val inherited = superType.lookupMember(name)
			if (inherited !== null) {
				return inherited
			}
		}
		return null
	}	
}
