Skip to content

Commit

Permalink
feat: Kotlin PSI structure for FIM. (#829)
Browse files Browse the repository at this point in the history
* Add Kotlin dependency and configuration file

* Add a "collect dependency structure" option to the code completion settings

* Add Kotlin PSI structure analysis and serialization components

* Add repository name and dependencies structure to InfillRequest and update prompt template for QWEN_2_5

---------

Co-authored-by: a.iudin <[email protected]>
  • Loading branch information
aaudin90 and a.iudin authored Jan 29, 2025
1 parent 2714f80 commit 72f31bf
Show file tree
Hide file tree
Showing 22 changed files with 685 additions and 12 deletions.
4 changes: 3 additions & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ intellij {
pluginName.set(properties("pluginName"))
version.set(properties("platformVersion"))
type.set(properties("platformType"))
plugins.set(listOf("java", "PythonCore:241.14494.240", "Git4Idea"))
plugins.set(listOf("java", "PythonCore:241.14494.240", "Git4Idea", "org.jetbrains.kotlin"))
}

changelog {
Expand All @@ -62,6 +62,8 @@ dependencies {
// vulnerable transitive dependency
exclude(group = "org.jsoup", module = "jsoup")
}
implementation(kotlin("stdlib"))
implementation(kotlin("reflect"))
implementation(libs.jsoup)
implementation(libs.commons.text)
implementation(libs.jtokkit)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package ee.carlrobert.codegpt.codecompletions

import ee.carlrobert.codegpt.codecompletions.psi.structure.ClassStructureSerializer
import org.jetbrains.kotlin.utils.addToStdlib.ifNotEmpty

enum class InfillPromptTemplate(val label: String, val stopTokens: List<String>? = listOf("\n\n")) {

OPENAI("OpenAI") {
Expand Down Expand Up @@ -71,16 +74,37 @@ enum class InfillPromptTemplate(val label: String, val stopTokens: List<String>?
/**
 * Builds the fill-in-the-middle prompt for this template.
 *
 * Three shapes are produced, checked in this order:
 *  1. Dependency structures are present: repository name, one `<|file_sep|>` section per
 *     serialized dependency class, then (only if available) the context-element sections and
 *     the enclosing file marker, then the FIM prompt.
 *  2. Only context elements are present: repository name taken from the context, the
 *     context-element sections, the enclosing file marker, then the FIM prompt.
 *  3. Neither is present: the bare FIM prompt.
 *
 * @param infillDetails request carrying prefix/suffix and the optional context and dependency data
 * @return the prompt text to send to the model
 */
override fun buildPrompt(infillDetails: InfillRequest): String {
    val infillPrompt =
        "<|fim_prefix|> ${infillDetails.prefix} <|fim_suffix|>${infillDetails.suffix} <|fim_middle|>"
    // Local copies so the null checks below smart-cast inside the builder lambdas.
    val context = infillDetails.context
    val dependenciesStructure = infillDetails.dependenciesStructure

    return when {
        dependenciesStructure != null -> buildString {
            append("<|repo_name|>${infillDetails.repositoryName}\n")
            append(
                dependenciesStructure.joinToString(separator = "\n", prefix = "\n") {
                    "<|file_sep|>${it.name.value}\n${ClassStructureSerializer.serialize(it)}\n"
                }
            )
            // The context section is optional here. Appending it conditionally (instead of
            // concatenating a nullable expression) keeps the literal text "null" out of the
            // prompt when there is no context or it has no elements.
            if (context != null && context.contextElements.isNotEmpty()) {
                context.contextElements.forEach {
                    append("<|file_sep|>${it.filePath()} \n${it.text()}\n")
                }
                append("<|file_sep|>${context.enclosingElement.filePath()} \n")
            }
            append(infillPrompt)
        }

        context != null && context.contextElements.isNotEmpty() -> buildString {
            append("<|repo_name|>${context.getRepoName()}\n")
            context.contextElements.forEach {
                append("<|file_sep|>${it.filePath()} \n${it.text()}\n")
            }
            append("<|file_sep|>${context.enclosingElement.filePath()} \n")
            append(infillPrompt)
        }

        else -> infillPrompt
    }
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.intellij.psi.PsiElement
import ee.carlrobert.codegpt.EncodingManager
import ee.carlrobert.codegpt.codecompletions.psi.filePath
import ee.carlrobert.codegpt.codecompletions.psi.readText
import ee.carlrobert.codegpt.codecompletions.psi.structure.models.ClassStructure

const val MAX_PROMPT_TOKENS = 256

Expand All @@ -15,6 +16,8 @@ class InfillRequest private constructor(
val suffix: String,
val caretOffset: Int,
val fileDetails: FileDetails?,
val repositoryName: String?,
val dependenciesStructure: Set<ClassStructure>?,
val context: InfillContext?,
val stopTokens: List<String>,
) {
Expand All @@ -27,6 +30,8 @@ class InfillRequest private constructor(
private val caretOffset: Int
private var fileDetails: FileDetails? = null
private var additionalContext: String? = null
private var repositoryName: String? = null
private var dependenciesStructure: Set<ClassStructure>? = null
private var context: InfillContext? = null
private var stopTokens: List<String>

Expand Down Expand Up @@ -61,6 +66,12 @@ class InfillRequest private constructor(
/** Stores [additionalContext] on this instance and returns the same instance so calls can be chained. */
fun additionalContext(additionalContext: String) =
apply { this.additionalContext = additionalContext }

/**
 * Stores [repositoryName] on this instance and returns the same instance so calls can be chained.
 * A later call replaces the value set by an earlier one.
 */
fun addRepositoryName(repositoryName: String) =
apply { this.repositoryName = repositoryName }

/**
 * Stores the given set of class structures on this instance and returns the same instance
 * so calls can be chained.
 */
fun addDependenciesStructure(dependenciesStructure: Set<ClassStructure>) =
apply { this.dependenciesStructure = dependenciesStructure }

/** Stores [context] on this instance and returns the same instance so calls can be chained. */
fun context(context: InfillContext) = apply { this.context = context }

private fun getStopTokens(type: CompletionType): List<String> {
Expand Down Expand Up @@ -94,6 +105,8 @@ class InfillRequest private constructor(
suffix,
caretOffset,
fileDetails,
repositoryName,
dependenciesStructure,
context,
stopTokens,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import com.intellij.refactoring.suggested.startOffset
import ee.carlrobert.codegpt.EncodingManager
import ee.carlrobert.codegpt.codecompletions.psi.CompletionContextService
import ee.carlrobert.codegpt.codecompletions.psi.readText
import ee.carlrobert.codegpt.codecompletions.psi.structure.PsiStructureProvider
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.util.GitUtil


object InfillRequestUtil {

suspend fun buildInfillRequest(
Expand All @@ -34,7 +36,18 @@ object InfillRequestUtil {
}

if (service<ConfigurationSettings>().state.codeCompletionSettings.contextAwareEnabled) {
getInfillContext(request, caretOffset)?.let { infillRequestBuilder.context(it) }
getInfillContext(request, caretOffset)?.let {
infillRequestBuilder.context(it)
infillRequestBuilder.addRepositoryName(it.getRepoName())
}
}

if (service<ConfigurationSettings>().state.codeCompletionSettings.collectDependencyStructure) {
val psiStructure = PsiStructureProvider().get(listOf(request.file))
if (psiStructure.isNotEmpty()) {
infillRequestBuilder.addDependenciesStructure(psiStructure)
infillRequestBuilder.addRepositoryName(psiStructure.first().repositoryName)
}
}

return infillRequestBuilder.build()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package ee.carlrobert.codegpt.codecompletions.psi.structure

import ee.carlrobert.codegpt.codecompletions.psi.structure.models.*
import org.jetbrains.kotlin.utils.addToStdlib.ifNotEmpty

/**
 * Renders a [ClassStructure] as compact, source-like text: a package line, the class header,
 * and the signatures of its constructors, fields, methods, enum entries and nested classes.
 * No bodies are emitted.
 */
object ClassStructureSerializer {

    // One unit of indentation per nesting level.
    // NOTE(review): the literal is kept exactly as found (a single space) — confirm this is the
    // intended width.
    private const val INDENTATION = " "

    /**
     * Serializes [classStructure] (and, recursively, its nested classes) to text.
     *
     * The original class:
     * ```
     * package org.example.package1
     *
     * import org.example.package2.ClassInPackage2
     *
     * class ClassInPackage1 {
     *
     *     fun someMethod1(classInPackage2: ClassInPackage2): String = classInPackage2.someMethod2()
     * }
     * ```
     *
     * Serialized representation of the structure (members are indented by [INDENTATION] per
     * nesting level):
     * ```
     * package org.example.package1
     *
     * class ClassInPackage1 {
     *     fun someMethod1(classInPackage2: org.example.package2.ClassInPackage2): String
     * }
     * ```
     *
     * @param classStructure the top-level class to render
     * @return the textual representation, starting with a `package` line
     */
    fun serialize(classStructure: ClassStructure): String =
        serializeInternal(classStructure)

    /**
     * Renders one class at nesting depth [level] (1 = top level). Only the top level emits the
     * `package` line; nested classes are rendered one level deeper.
     */
    private fun serializeInternal(classStructure: ClassStructure, level: Int = 1): String {
        val bodyIndent = INDENTATION.repeat(level)
        val classIndent = INDENTATION.repeat(level - 1)

        val modifiers = classStructure.modifierList.joinToString(" ").withTrailingSpace()
        val classType = classStructure.classType.name.lowercase()
        val className = classStructure.simpleName.value
        val supertypes = classStructure.supertypes.joinToString(", ") { it.value }
        val packageName = classStructure.packageName

        // The first constructor is rendered in the class header; the rest go into the body.
        val primaryConstructor = classStructure.constructors.firstOrNull()?.let { serializePrimaryConstructor(it) }

        // Each body section is either empty or a block whose every line already carries the
        // body indentation, so the sections can be appended as-is below.
        val secondaryConstructors = classStructure.constructors
            .drop(1)
            .joinToString("\n") { bodyIndent + serializeSecondaryConstructor(it) }
        val fields = classStructure.fields.joinToString("\n") { bodyIndent + serializeField(it) }
        val methods = classStructure.methods.joinToString("\n") { bodyIndent + serializeMethod(it) }
        val enumEntries = classStructure.enumEntries.joinToString(",\n") { bodyIndent + it.value }
        val innerClasses = classStructure.classes.joinToString("\n\n") { serializeInternal(it, level + 1) }

        return buildString {
            if (level == 1) {
                append("package ${packageName.ifEmpty { "Unknown" }}\n\n")
            }

            // `modifiers` already ends with a space when non-empty, so no extra separator is
            // needed before the keyword (avoids a double or leading space).
            append(classIndent).append(modifiers)
            if (classStructure.classType == ClassType.COMPANION_OBJECT) {
                append("object")
            } else {
                append("$classType $className")
            }

            if (primaryConstructor != null) {
                append("($primaryConstructor)")
            }

            if (supertypes.isNotEmpty()) {
                append(" : $supertypes")
            }
            append(" {\n")

            // Skip the section entirely for an enum without entries instead of emitting a blank line.
            if (classStructure.classType == ClassType.ENUM && enumEntries.isNotEmpty()) {
                append("$enumEntries\n")
            }

            // The block is already indented per line; do not indent the first line a second time.
            if (secondaryConstructors.isNotEmpty()) {
                append("$secondaryConstructors\n")
            }

            if (fields.isNotEmpty()) {
                append("$fields\n")
            }

            if (methods.isNotEmpty()) {
                append("$methods\n")
            }

            if (innerClasses.isNotEmpty()) {
                append("$innerClasses\n")
            }

            append("$classIndent}")
        }
    }

    /** Renders the parameter list of the primary constructor (without the surrounding parentheses). */
    private fun serializePrimaryConstructor(constructor: ConstructorStructure): String =
        constructor.parameters.joinToString(", ") { serializeParameter(it) }

    /** Renders a secondary constructor as `[modifiers ]constructor(params)`. */
    private fun serializeSecondaryConstructor(constructor: ConstructorStructure): String {
        val parameters = constructor.parameters.joinToString(", ") { serializeParameter(it) }
        val modifiers = constructor.modifiers.joinToString(" ").withTrailingSpace()
        return "${modifiers}constructor($parameters)"
    }

    /** Renders a field as `[modifiers ]name: Type`. */
    private fun serializeField(field: FieldStructure): String {
        val modifiers = field.modifiers.joinToString(" ").withTrailingSpace()
        return "$modifiers${field.name}: ${field.type.value}"
    }

    /** Renders a method signature as `[modifiers ]fun name(params): ReturnType`. */
    private fun serializeMethod(method: MethodStructure): String {
        val modifiers = method.modifiers.joinToString(" ").withTrailingSpace()
        val parameters = method.parameters.joinToString(", ") { serializeParameter(it) }
        return "${modifiers}fun ${method.name}($parameters): ${method.returnType.value}"
    }

    /** Renders a parameter as `[modifiers ]name: Type`. */
    private fun serializeParameter(parameter: ParameterInfo): String {
        val modifiers = parameter.modifiers.joinToString(" ").withTrailingSpace()
        return "$modifiers${parameter.name}: ${parameter.type.value}"
    }

    /** Returns the receiver followed by a single space, or the empty string unchanged. */
    private fun String.withTrailingSpace(): String = if (isEmpty()) this else "$this "
}
Loading

0 comments on commit 72f31bf

Please sign in to comment.