diff --git a/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java index 5ab5a380a0..1cf3472b23 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java @@ -19,22 +19,19 @@ import lombok.Value; import org.jspecify.annotations.Nullable; import org.openrewrite.*; -import org.openrewrite.marker.Markers; import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.internal.PythonDependencyExecutionContextView; -import org.openrewrite.python.marker.PythonResolutionResult; -import org.openrewrite.toml.TomlIsoVisitor; -import org.openrewrite.toml.tree.Space; +import org.openrewrite.python.trait.PythonDependencyFile; import org.openrewrite.toml.tree.Toml; -import org.openrewrite.toml.tree.TomlRightPadded; -import org.openrewrite.toml.tree.TomlType; -import java.util.*; - -import static org.openrewrite.Tree.randomId; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; /** - * Add a dependency to the {@code [project].dependencies} array in pyproject.toml. + * Add a dependency to a Python project. Supports {@code pyproject.toml} + * (with scope and group targeting), {@code requirements.txt}, and {@code Pipfile}. * When uv is available, the uv.lock file is regenerated to reflect the change. */ @EqualsAndHashCode(callSuper = false) @@ -54,7 +51,9 @@ public class AddDependency extends ScanningRecipe { String version; @Option(displayName = "Scope", - description = "The dependency scope to add to. Defaults to `project.dependencies`.", + description = "The dependency scope to add to. For pyproject.toml this targets a specific TOML section. " + + "For requirements files, `null` matches all files, empty string matches only `requirements.txt`, " + + "and a value like `dev` matches `requirements-dev.txt`. Defaults to `project.dependencies`.", valid = {"project.dependencies", "project.optional-dependencies", "dependency-groups", "tool.uv.constraint-dependencies", "tool.uv.override-dependencies"}, example = "project.dependencies", @@ -90,7 +89,8 @@ public String getInstanceNameSuffix() { @Override public String getDescription() { - return "Add a dependency to the `[project].dependencies` array in `pyproject.toml`. " + + return "Add a dependency to a Python project. Supports `pyproject.toml` " + + "(with scope/group targeting), `requirements.txt`, and `Pipfile`. " + "When `uv` is available, the `uv.lock` file is regenerated."; } @@ -105,141 +105,67 @@ public Accumulator getInitialValue(ExecutionContext ctx) { @Override public TreeVisitor getScanner(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("uv.lock")) { - PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( - PyProjectHelper.correspondingPyprojectPath(sourcePath), - document.printAll()); - return document; + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; } - - if (!sourcePath.endsWith("pyproject.toml")) { - return document; + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + if (tree instanceof Toml.Document && sourceFile.getSourcePath().toString().endsWith("uv.lock")) { + PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( + PyProjectHelper.correspondingPyprojectPath(sourceFile.getSourcePath().toString()), + ((Toml.Document) tree).printAll()); + return tree; } - Optional resolution = document.getMarkers() - .findFirst(PythonResolutionResult.class); - if (!resolution.isPresent()) { - return document; + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait == null) { + return tree; } - - PythonResolutionResult marker = resolution.get(); - - // Check if the dependency already exists in the target scope - if (PyProjectHelper.findDependencyInScope(marker, packageName, scope, groupName) != null) { - return document; + if (PyProjectHelper.findDependencyInScope(trait.getMarker(), packageName, scope, groupName) != null) { + return tree; } - - acc.projectsToUpdate.add(sourcePath); - return document; + acc.projectsToUpdate.add(sourceFile.getSourcePath().toString()); + return tree; } }; } @Override public TreeVisitor getVisitor(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("pyproject.toml") && acc.projectsToUpdate.contains(sourcePath)) { - return addDependencyToPyproject(document, ctx, acc); + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; + } + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + String sourcePath = sourceFile.getSourcePath().toString(); + + if (acc.projectsToUpdate.contains(sourcePath)) { + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait != null) { + String ver = version != null ? version : ""; + Map additions = Collections.singletonMap(packageName, ver); + PythonDependencyFile updated = trait.withAddedDependencies(additions, scope, groupName); + if (updated.getTree() != tree) { + return updated.afterModification(ctx); + } + } } - if (sourcePath.endsWith("uv.lock")) { - Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock(document, ctx); + if (tree instanceof Toml.Document && sourcePath.endsWith("uv.lock")) { + Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock((Toml.Document) tree, ctx); if (updatedLock != null) { return updatedLock; } } - return document; + return tree; } }; } - private Toml.Document addDependencyToPyproject(Toml.Document document, ExecutionContext ctx, Accumulator acc) { - String pep508 = version != null ? packageName + PyProjectHelper.normalizeVersionConstraint(version) : packageName; - - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { - @Override - public Toml.Array visitArray(Toml.Array array, ExecutionContext ctx) { - Toml.Array a = super.visitArray(array, ctx); - - if (!PyProjectHelper.isInsideDependencyArray(getCursor(), scope, groupName)) { - return a; - } - - Toml.Literal newLiteral = new Toml.Literal( - randomId(), - Space.EMPTY, - Markers.EMPTY, - TomlType.Primitive.String, - "\"" + pep508 + "\"", - pep508 - ); - - List> existingPadded = a.getPadding().getValues(); - List> newPadded = new ArrayList<>(); - - // An empty TOML array [] is represented as a single Toml.Empty element - boolean isEmpty = existingPadded.size() == 1 && - existingPadded.get(0).getElement() instanceof Toml.Empty; - if (existingPadded.isEmpty() || isEmpty) { - newPadded.add(new TomlRightPadded<>(newLiteral, Space.EMPTY, Markers.EMPTY)); - } else { - // Check if the last element is Toml.Empty (trailing comma marker) - TomlRightPadded lastPadded = existingPadded.get(existingPadded.size() - 1); - boolean hasTrailingComma = lastPadded.getElement() instanceof Toml.Empty; - - if (hasTrailingComma) { - // Insert before the Empty element. The Empty's position - // stores the whitespace before ']'. - // Find the last real element to copy its prefix formatting - int lastRealIdx = existingPadded.size() - 2; - Toml lastRealElement = existingPadded.get(lastRealIdx).getElement(); - Toml.Literal formattedLiteral = newLiteral.withPrefix(lastRealElement.getPrefix()); - - // Copy all existing elements up to (not including) the Empty - for (int i = 0; i <= lastRealIdx; i++) { - newPadded.add(existingPadded.get(i)); - } - // Add new literal with empty after (comma added by printer) - newPadded.add(new TomlRightPadded<>(formattedLiteral, Space.EMPTY, Markers.EMPTY)); - // Keep the Empty element for trailing comma + closing bracket whitespace - newPadded.add(lastPadded); - } else { - // No trailing comma — the last real element's after has the space before ']' - Toml lastElement = lastPadded.getElement(); - // For multi-line arrays, use same prefix; for inline, use single space - Space newPrefix = lastElement.getPrefix().getWhitespace().contains("\n") - ? lastElement.getPrefix() - : Space.SINGLE_SPACE; - Toml.Literal formattedLiteral = newLiteral.withPrefix(newPrefix); - - // Copy all existing elements but set last one's after to empty - for (int i = 0; i < existingPadded.size() - 1; i++) { - newPadded.add(existingPadded.get(i)); - } - newPadded.add(lastPadded.withAfter(Space.EMPTY)); - // New element gets the after from the old last element - newPadded.add(new TomlRightPadded<>(formattedLiteral, lastPadded.getAfter(), Markers.EMPTY)); - } - } - - return a.getPadding().withValues(newPadded); - } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; - } - } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/Assertions.java b/rewrite-python/src/main/java/org/openrewrite/python/Assertions.java index a785c7d463..2df0a0ab8b 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/Assertions.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/Assertions.java @@ -237,6 +237,45 @@ public static SourceSpecs setupCfg(@Nullable String before, return text; } + public static SourceSpecs pipfile(@Language("toml") @Nullable String before) { + return pipfile(before, s -> { + }); + } + + public static SourceSpecs pipfile(@Language("toml") @Nullable String before, + Consumer> spec) { + SourceSpec toml = new SourceSpec<>( + Toml.Document.class, null, PipfileParser.builder(), before, + SourceSpec.ValidateSource.noop, + ctx -> { + } + ); + toml.path("Pipfile"); + spec.accept(toml); + return toml; + } + + public static SourceSpecs pipfile(@Language("toml") @Nullable String before, + @Language("toml") @Nullable String after) { + return pipfile(before, after, s -> { + }); + } + + public static SourceSpecs pipfile(@Language("toml") @Nullable String before, + @Language("toml") @Nullable String after, + Consumer> spec) { + SourceSpec toml = new SourceSpec<>( + Toml.Document.class, null, PipfileParser.builder(), before, + SourceSpec.ValidateSource.noop, + ctx -> { + } + ); + toml.path("Pipfile"); + toml.after(s -> after); + spec.accept(toml); + return toml; + } + public static SourceSpecs python(@Language("py") @Nullable String before) { return python(before, s -> { }); diff --git a/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java index d945fb7978..a9f8f9952e 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java @@ -18,22 +18,18 @@ import lombok.EqualsAndHashCode; import lombok.Value; import org.jspecify.annotations.Nullable; -import org.openrewrite.ExecutionContext; -import org.openrewrite.Option; -import org.openrewrite.ScanningRecipe; -import org.openrewrite.TreeVisitor; +import org.openrewrite.*; import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.internal.PythonDependencyExecutionContextView; -import org.openrewrite.python.marker.PythonResolutionResult; -import org.openrewrite.toml.TomlIsoVisitor; +import org.openrewrite.python.trait.PythonDependencyFile; import org.openrewrite.toml.tree.Toml; -import org.openrewrite.toml.tree.TomlType; -import java.util.*; +import java.util.HashSet; +import java.util.Set; /** - * Change a dependency to a different package in pyproject.toml. - * Searches all dependency arrays in the document (no scope restriction). + * Change a dependency to a different package. Supports {@code pyproject.toml}, + * {@code requirements.txt}, and {@code Pipfile}. Searches all dependency scopes. * When uv is available, the uv.lock file is regenerated to reflect the change. */ @EqualsAndHashCode(callSuper = false) @@ -69,8 +65,9 @@ public String getInstanceNameSuffix() { @Override public String getDescription() { - return "Change a dependency to a different package in `pyproject.toml`. " + - "Searches all dependency arrays. When `uv` is available, the `uv.lock` file is regenerated."; + return "Change a dependency to a different package. Supports `pyproject.toml`, " + + "`requirements.txt`, and `Pipfile`. Searches all dependency scopes. " + + "When `uv` is available, the `uv.lock` file is regenerated."; } static class Accumulator { @@ -84,130 +81,63 @@ public Accumulator getInitialValue(ExecutionContext ctx) { @Override public TreeVisitor getScanner(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("uv.lock")) { - PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( - PyProjectHelper.correspondingPyprojectPath(sourcePath), - document.printAll()); - return document; + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; } - - if (!sourcePath.endsWith("pyproject.toml")) { - return document; + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + if (tree instanceof Toml.Document && sourceFile.getSourcePath().toString().endsWith("uv.lock")) { + PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( + PyProjectHelper.correspondingPyprojectPath(sourceFile.getSourcePath().toString()), + ((Toml.Document) tree).printAll()); + return tree; } - Optional resolution = document.getMarkers() - .findFirst(PythonResolutionResult.class); - if (!resolution.isPresent()) { - return document; + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait == null) { + return tree; } - - PythonResolutionResult marker = resolution.get(); - if (marker.findDependencyInAnyScope(oldPackageName) != null) { - acc.projectsToUpdate.add(sourcePath); + if (trait.getMarker().findDependencyInAnyScope(oldPackageName) != null) { + acc.projectsToUpdate.add(sourceFile.getSourcePath().toString()); } - return document; + return tree; } }; } @Override public TreeVisitor getVisitor(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("pyproject.toml") && acc.projectsToUpdate.contains(sourcePath)) { - return changeDependencyInPyproject(document, ctx, acc); + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; + } + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + String sourcePath = sourceFile.getSourcePath().toString(); + + if (acc.projectsToUpdate.contains(sourcePath)) { + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait != null) { + PythonDependencyFile updated = trait.withChangedDependency(oldPackageName, newPackageName, newVersion); + if (updated.getTree() != tree) { + return updated.afterModification(ctx); + } + } } - if (sourcePath.endsWith("uv.lock")) { - Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock(document, ctx); + if (tree instanceof Toml.Document && sourcePath.endsWith("uv.lock")) { + Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock((Toml.Document) tree, ctx); if (updatedLock != null) { return updatedLock; } } - return document; + return tree; } }; } - - private Toml.Document changeDependencyInPyproject(Toml.Document document, ExecutionContext ctx, Accumulator acc) { - String normalizedOld = PythonResolutionResult.normalizeName(oldPackageName); - - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { - @Override - public Toml.Literal visitLiteral(Toml.Literal literal, ExecutionContext ctx) { - Toml.Literal l = super.visitLiteral(literal, ctx); - if (l.getType() != TomlType.Primitive.String) { - return l; - } - - Object val = l.getValue(); - if (!(val instanceof String)) { - return l; - } - - String spec = (String) val; - String depName = PyProjectHelper.extractPackageName(spec); - if (depName == null || !PythonResolutionResult.normalizeName(depName).equals(normalizedOld)) { - return l; - } - - // Build new PEP 508 string - String extras = UpgradeDependencyVersion.extractExtras(spec); - String marker = UpgradeDependencyVersion.extractMarker(spec); - - StringBuilder sb = new StringBuilder(newPackageName); - if (extras != null) { - sb.append('[').append(extras).append(']'); - } - if (newVersion != null) { - sb.append(PyProjectHelper.normalizeVersionConstraint(newVersion)); - } else { - // Preserve the original version constraint - String originalVersion = extractVersionConstraint(spec, depName); - if (originalVersion != null) { - sb.append(originalVersion); - } - } - if (marker != null) { - sb.append("; ").append(marker); - } - - String newSpec = sb.toString(); - return l.withSource("\"" + newSpec + "\"").withValue(newSpec); - } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; - } - - /** - * Extract the version constraint portion from a PEP 508 spec. - * Returns the version constraint (e.g. ">=2.28.0") or null if there is none. - */ - private static @Nullable String extractVersionConstraint(String spec, String name) { - String remaining = spec.substring(name.length()).trim(); - // Skip extras [...] - if (remaining.startsWith("[")) { - int end = remaining.indexOf(']'); - if (end >= 0) { - remaining = remaining.substring(end + 1).trim(); - } - } - // Extract version constraint up to marker - int markerIdx = remaining.indexOf(';'); - String versionPart = markerIdx >= 0 ? remaining.substring(0, markerIdx).trim() : remaining.trim(); - return versionPart.isEmpty() ? null : versionPart; - } } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/PipfileParser.java b/rewrite-python/src/main/java/org/openrewrite/python/PipfileParser.java new file mode 100644 index 0000000000..5d3e11df39 --- /dev/null +++ b/rewrite-python/src/main/java/org/openrewrite/python/PipfileParser.java @@ -0,0 +1,192 @@ +/* + * Copyright 2026 the original author or authors. + * + * Moderne Proprietary. Only for use by Moderne customers under the terms of a commercial contract. + */ +package org.openrewrite.python; + +import org.jspecify.annotations.Nullable; +import org.openrewrite.ExecutionContext; +import org.openrewrite.Parser; +import org.openrewrite.SourceFile; +import org.openrewrite.python.internal.PyProjectHelper; +import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.python.marker.PythonResolutionResult.Dependency; +import org.openrewrite.python.marker.PythonResolutionResult.PackageManager; +import org.openrewrite.toml.TomlParser; +import org.openrewrite.toml.tree.Toml; + +import java.nio.file.Path; +import java.util.*; +import java.util.stream.Stream; + +import static org.openrewrite.Tree.randomId; + +/** + * Parser for Pipfile files that delegates to {@link TomlParser} and attaches a + * {@link PythonResolutionResult} marker with dependency metadata. + */ +public class PipfileParser implements Parser { + + private final TomlParser tomlParser = new TomlParser(); + + @Override + public Stream parseInputs(Iterable sources, @Nullable Path relativeTo, ExecutionContext ctx) { + return tomlParser.parseInputs(sources, relativeTo, ctx).map(sf -> { + if (!(sf instanceof Toml.Document)) { + return sf; + } + Toml.Document doc = (Toml.Document) sf; + PythonResolutionResult marker = createMarker(doc); + if (marker == null) { + return sf; + } + return doc.withMarkers(doc.getMarkers().addIfAbsent(marker)); + }); + } + + static @Nullable PythonResolutionResult createMarker(Toml.Document doc) { + Map tables = indexTables(doc); + + Toml.Table packagesTable = tables.get("packages"); + Toml.Table devPackagesTable = tables.get("dev-packages"); + + // A Pipfile should have at least one dependency section + if (packagesTable == null && devPackagesTable == null) { + return null; + } + + List dependencies = parseDependencyTable(packagesTable); + + Map> optionalDependencies = new LinkedHashMap<>(); + List devDeps = parseDependencyTable(devPackagesTable); + if (!devDeps.isEmpty()) { + optionalDependencies.put("dev-packages", devDeps); + } + + Toml.Table requiresTable = tables.get("requires"); + String requiresPython = requiresTable != null ? getStringValue(requiresTable, "python_version") : null; + + return new PythonResolutionResult( + randomId(), + null, + null, + null, + null, + doc.getSourcePath().toString(), + requiresPython, + null, + Collections.emptyList(), + dependencies, + optionalDependencies, + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + PackageManager.Pipenv, + null + ); + } + + private static List parseDependencyTable(Toml.@Nullable Table table) { + if (table == null) { + return Collections.emptyList(); + } + List deps = new ArrayList<>(); + for (Toml value : table.getValues()) { + if (!(value instanceof Toml.KeyValue)) { + continue; + } + Toml.KeyValue kv = (Toml.KeyValue) value; + if (!(kv.getKey() instanceof Toml.Identifier)) { + continue; + } + String name = ((Toml.Identifier) kv.getKey()).getName(); + String versionConstraint = extractVersion(kv.getValue()); + if ("*".equals(versionConstraint)) { + versionConstraint = null; + } + deps.add(new Dependency(name, versionConstraint, null, null, null)); + } + return deps; + } + + private static @Nullable String extractVersion(Toml value) { + if (value instanceof Toml.Literal) { + Object v = ((Toml.Literal) value).getValue(); + return v instanceof String ? (String) v : null; + } + if (value instanceof Toml.Table) { + // Inline table: {version = ">=3.2", ...} + for (Toml inner : ((Toml.Table) value).getValues()) { + if (inner instanceof Toml.KeyValue) { + Toml.KeyValue innerKv = (Toml.KeyValue) inner; + if (innerKv.getKey() instanceof Toml.Identifier && + "version".equals(((Toml.Identifier) innerKv.getKey()).getName())) { + return extractVersion(innerKv.getValue()); + } + } + } + } + return null; + } + + private static @Nullable String getStringValue(Toml.Table table, String key) { + for (Toml value : table.getValues()) { + if (value instanceof Toml.KeyValue) { + Toml.KeyValue kv = (Toml.KeyValue) value; + if (kv.getKey() instanceof Toml.Identifier && + key.equals(((Toml.Identifier) kv.getKey()).getName()) && + kv.getValue() instanceof Toml.Literal) { + Object v = ((Toml.Literal) kv.getValue()).getValue(); + return v instanceof String ? (String) v : null; + } + } + } + return null; + } + + private static Map indexTables(Toml.Document doc) { + Map tables = new LinkedHashMap<>(); + for (Toml value : doc.getValues()) { + if (value instanceof Toml.Table) { + Toml.Table table = (Toml.Table) value; + if (table.getName() != null) { + tables.put(table.getName().getName(), table); + } + } + } + return tables; + } + + @Override + public boolean accept(Path path) { + return "Pipfile".equals(path.getFileName().toString()); + } + + @Override + public Path sourcePathFromSourceText(Path prefix, String sourceCode) { + return prefix.resolve("Pipfile"); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder extends Parser.Builder { + + Builder() { + super(Toml.Document.class); + } + + @Override + public PipfileParser build() { + return new PipfileParser(); + } + + @Override + public String getDslName() { + return "Pipfile"; + } + } +} diff --git a/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java index 0b62c3f411..8fd8e46f1b 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java @@ -21,16 +21,16 @@ import org.openrewrite.*; import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.internal.PythonDependencyExecutionContextView; -import org.openrewrite.python.marker.PythonResolutionResult; -import org.openrewrite.toml.TomlIsoVisitor; -import org.openrewrite.toml.tree.Space; +import org.openrewrite.python.trait.PythonDependencyFile; import org.openrewrite.toml.tree.Toml; -import org.openrewrite.toml.tree.TomlRightPadded; -import java.util.*; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; /** - * Remove a dependency from the {@code [project].dependencies} array in pyproject.toml. + * Remove a dependency from a Python project. Supports {@code pyproject.toml} + * (with scope and group targeting), {@code requirements.txt}, and {@code Pipfile}. * When uv is available, the uv.lock file is regenerated to reflect the change. */ @EqualsAndHashCode(callSuper = false) @@ -79,7 +79,8 @@ public String getInstanceNameSuffix() { @Override public String getDescription() { - return "Remove a dependency from the `[project].dependencies` array in `pyproject.toml`. " + + return "Remove a dependency from a Python project. Supports `pyproject.toml` " + + "(with scope/group targeting), `requirements.txt`, and `Pipfile`. " + "When `uv` is available, the `uv.lock` file is regenerated."; } @@ -94,124 +95,66 @@ public Accumulator getInitialValue(ExecutionContext ctx) { @Override public TreeVisitor getScanner(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("uv.lock")) { - PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( - PyProjectHelper.correspondingPyprojectPath(sourcePath), - document.printAll()); - return document; + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; } - - if (!sourcePath.endsWith("pyproject.toml")) { - return document; + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + if (tree instanceof Toml.Document && sourceFile.getSourcePath().toString().endsWith("uv.lock")) { + PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( + PyProjectHelper.correspondingPyprojectPath(sourceFile.getSourcePath().toString()), + ((Toml.Document) tree).printAll()); + return tree; } - Optional resolution = document.getMarkers() - .findFirst(PythonResolutionResult.class); - if (!resolution.isPresent()) { - return document; + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait == null) { + return tree; } - - PythonResolutionResult marker = resolution.get(); - - // Check if the dependency exists in the target scope - if (PyProjectHelper.findDependencyInScope(marker, packageName, scope, groupName) == null) { - return document; + if (PyProjectHelper.findDependencyInScope(trait.getMarker(), packageName, scope, groupName) == null) { + return tree; } - - acc.projectsToUpdate.add(sourcePath); - return document; + acc.projectsToUpdate.add(sourceFile.getSourcePath().toString()); + return tree; } }; } @Override public TreeVisitor getVisitor(Accumulator acc) { - return new TomlIsoVisitor() { - @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("pyproject.toml") && acc.projectsToUpdate.contains(sourcePath)) { - return removeDependencyFromPyproject(document, ctx, acc); - } - - if (sourcePath.endsWith("uv.lock")) { - Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock(document, ctx); - if (updatedLock != null) { - return updatedLock; - } - } - - return document; - } - }; - } - - private Toml.Document removeDependencyFromPyproject(Toml.Document document, ExecutionContext ctx, Accumulator acc) { - String normalizedName = PythonResolutionResult.normalizeName(packageName); - - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Array visitArray(Toml.Array array, ExecutionContext ctx) { - Toml.Array a = super.visitArray(array, ctx); - - if (!PyProjectHelper.isInsideDependencyArray(getCursor(), scope, groupName)) { - return a; + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; } - - // Find and remove the matching dependency - List> existingPadded = a.getPadding().getValues(); - List> newPadded = new ArrayList<>(); - boolean found = false; - int removedIdx = -1; - - for (int i = 0; i < existingPadded.size(); i++) { - TomlRightPadded padded = existingPadded.get(i); - Toml element = padded.getElement(); - - if (!found && element instanceof Toml.Literal) { - Object val = ((Toml.Literal) element).getValue(); - if (val instanceof String) { - String depName = PyProjectHelper.extractPackageName((String) val); - if (depName != null && PythonResolutionResult.normalizeName(depName).equals(normalizedName)) { - found = true; - removedIdx = i; - continue; - } + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + String sourcePath = sourceFile.getSourcePath().toString(); + + if (acc.projectsToUpdate.contains(sourcePath)) { + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait != null) { + PythonDependencyFile updated = trait.withRemovedDependencies( + Collections.singleton(packageName), scope, groupName); + if (updated.getTree() != tree) { + return updated.afterModification(ctx); } } - - newPadded.add(padded); } - if (!found) { - return a; - } - - // If the removed element was the first one, the next element - // may have a space prefix from comma formatting. Transfer the - // removed element's prefix to the first remaining real element. - if (removedIdx == 0 && !newPadded.isEmpty()) { - TomlRightPadded first = newPadded.get(0); - if (!(first.getElement() instanceof Toml.Empty)) { - Space originalPrefix = existingPadded.get(removedIdx).getElement().getPrefix(); - newPadded.set(0, first.map(el -> el.withPrefix(originalPrefix))); + if (tree instanceof Toml.Document && sourcePath.endsWith("uv.lock")) { + Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock((Toml.Document) tree, ctx); + if (updatedLock != null) { + return updatedLock; } } - return a.getPadding().withValues(newPadded); + return tree; } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; + }; } } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java index 159500b96f..7bec9b8c1e 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java @@ -22,14 +22,17 @@ import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.internal.PythonDependencyExecutionContextView; import org.openrewrite.python.marker.PythonResolutionResult; -import org.openrewrite.toml.TomlIsoVisitor; +import org.openrewrite.python.trait.PythonDependencyFile; import org.openrewrite.toml.tree.Toml; -import org.openrewrite.toml.tree.TomlType; -import java.util.*; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; /** - * Upgrade the version constraint for a dependency in {@code [project].dependencies} in pyproject.toml. + * Upgrade the version constraint for a dependency. Supports {@code pyproject.toml} + * (with scope and group targeting), {@code requirements.txt}, and {@code Pipfile}. * When uv is available, the uv.lock file is regenerated to reflect the change. */ @EqualsAndHashCode(callSuper = false) @@ -83,7 +86,8 @@ public String getInstanceNameSuffix() { @Override public String getDescription() { - return "Upgrade the version constraint for a dependency in `[project].dependencies` in `pyproject.toml`. " + + return "Upgrade the version constraint for a dependency. Supports `pyproject.toml` " + + "(with scope/group targeting), `requirements.txt`, and `Pipfile`. " + "When `uv` is available, the `uv.lock` file is regenerated."; } @@ -98,155 +102,71 @@ public Accumulator getInitialValue(ExecutionContext ctx) { @Override public TreeVisitor getScanner(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("uv.lock")) { - PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( - PyProjectHelper.correspondingPyprojectPath(sourcePath), - document.printAll()); - return document; + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; } - - if (!sourcePath.endsWith("pyproject.toml")) { - return document; + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + if (tree instanceof Toml.Document && sourceFile.getSourcePath().toString().endsWith("uv.lock")) { + PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( + PyProjectHelper.correspondingPyprojectPath(sourceFile.getSourcePath().toString()), + ((Toml.Document) tree).printAll()); + return tree; } - Optional resolution = document.getMarkers() - .findFirst(PythonResolutionResult.class); - if (!resolution.isPresent()) { - return document; + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait == null) { + return tree; } - - PythonResolutionResult marker = resolution.get(); - - // Check if the dependency exists in the target scope and has a different version PythonResolutionResult.Dependency dep = PyProjectHelper.findDependencyInScope( - marker, packageName, scope, groupName); + trait.getMarker(), packageName, scope, groupName); if (dep == null) { - return document; + return tree; } - - // Skip if the version constraint already matches if (PyProjectHelper.normalizeVersionConstraint(newVersion).equals(dep.getVersionConstraint())) { - return document; + return tree; } - - acc.projectsToUpdate.add(sourcePath); - return document; + acc.projectsToUpdate.add(sourceFile.getSourcePath().toString()); + return tree; } }; } @Override public TreeVisitor getVisitor(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("pyproject.toml") && acc.projectsToUpdate.contains(sourcePath)) { - return changeVersionInPyproject(document, ctx, acc); + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; + } + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + String sourcePath = sourceFile.getSourcePath().toString(); + + if (acc.projectsToUpdate.contains(sourcePath)) { + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait != null) { + Map upgrades = Collections.singletonMap( + PythonResolutionResult.normalizeName(packageName), newVersion); + PythonDependencyFile updated = trait.withUpgradedVersions(upgrades, scope, groupName); + if (updated.getTree() != tree) { + return updated.afterModification(ctx); + } + } } - if (sourcePath.endsWith("uv.lock")) { - Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock(document, ctx); + if (tree instanceof Toml.Document && sourcePath.endsWith("uv.lock")) { + Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock((Toml.Document) tree, ctx); if (updatedLock != null) { return updatedLock; } } - return document; + return tree; } }; } - - private Toml.Document changeVersionInPyproject(Toml.Document document, ExecutionContext ctx, Accumulator acc) { - String normalizedName = PythonResolutionResult.normalizeName(packageName); - - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { - @Override - public Toml.Literal visitLiteral(Toml.Literal literal, ExecutionContext ctx) { - Toml.Literal l = super.visitLiteral(literal, ctx); - if (l.getType() != TomlType.Primitive.String) { - return l; - } - - Object val = l.getValue(); - if (!(val instanceof String)) { - return l; - } - - // Check if we're inside the target dependency array - if (!isInsideTargetDependencies()) { - return l; - } - - // Check if this literal matches the package we're looking for - String spec = (String) val; - String depName = PyProjectHelper.extractPackageName(spec); - if (depName == null || !PythonResolutionResult.normalizeName(depName).equals(normalizedName)) { - return l; - } - - // Build new PEP 508 string preserving extras and markers - String newSpec = buildNewSpec(spec, depName); - return l.withSource("\"" + newSpec + "\"").withValue(newSpec); - } - - private boolean isInsideTargetDependencies() { - // Walk up the cursor to find the enclosing array, then check scope - Cursor c = getCursor(); - while (c != null) { - if (c.getValue() instanceof Toml.Array) { - return PyProjectHelper.isInsideDependencyArray(c, scope, groupName); - } - c = c.getParent(); - } - return false; - } - - private String buildNewSpec(String oldSpec, String depName) { - // Parse extras and markers from old spec - String extras = extractExtras(oldSpec); - String marker = extractMarker(oldSpec); - - StringBuilder sb = new StringBuilder(depName); - if (extras != null) { - sb.append('[').append(extras).append(']'); - } - sb.append(PyProjectHelper.normalizeVersionConstraint(newVersion)); - if (marker != null) { - sb.append("; ").append(marker); - } - return sb.toString(); - } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; - } - - static @Nullable String extractExtras(String pep508Spec) { - int start = pep508Spec.indexOf('['); - int end = pep508Spec.indexOf(']'); - if (start >= 0 && end > start) { - return pep508Spec.substring(start + 1, end); - } - return null; - } - - static @Nullable String extractMarker(String pep508Spec) { - int idx = pep508Spec.indexOf(';'); - if (idx >= 0) { - String marker = pep508Spec.substring(idx + 1).trim(); - return marker.isEmpty() ? null : marker; - } - return null; - } - } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java index 7b6f286d97..3448ff467f 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java @@ -19,29 +19,24 @@ import lombok.Value; import org.jspecify.annotations.Nullable; import org.openrewrite.*; -import org.openrewrite.marker.Markers; import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.internal.PythonDependencyExecutionContextView; import org.openrewrite.python.marker.PythonResolutionResult; -import org.openrewrite.python.marker.PythonResolutionResult.Dependency; -import org.openrewrite.toml.TomlIsoVisitor; -import org.openrewrite.toml.tree.Space; +import org.openrewrite.python.trait.PythonDependencyFile; import org.openrewrite.toml.tree.Toml; -import org.openrewrite.toml.tree.TomlRightPadded; -import org.openrewrite.toml.tree.TomlType; -import java.util.*; - -import static org.openrewrite.Tree.randomId; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; /** - * Pin a transitive dependency version by adding or upgrading a constraint in the - * appropriate tool-specific section. The strategy depends on the detected package manager: - *
    - *
  • uv: uses {@code [tool.uv].constraint-dependencies}
  • - *
  • PDM: uses {@code [tool.pdm.overrides]}
  • - *
  • Other/unknown: adds as a direct dependency in {@code [project].dependencies}
  • - *
+ * Pin a transitive dependency version using the strategy appropriate for the file type + * and package manager. For {@code pyproject.toml}: uv uses + * {@code [tool.uv].constraint-dependencies}, PDM uses {@code [tool.pdm.overrides]}, + * and other managers add a direct dependency. For {@code requirements.txt} and + * {@code Pipfile}: appends the dependency. When uv is available, the uv.lock file + * is regenerated. */ @EqualsAndHashCode(callSuper = false) @Value @@ -69,23 +64,14 @@ public String getInstanceNameSuffix() { @Override public String getDescription() { - return "Pin a transitive dependency version using the appropriate strategy for the " + - "detected package manager: uv uses `[tool.uv].constraint-dependencies`, " + - "PDM uses `[tool.pdm.overrides]`, and other managers add a direct dependency."; - } - - enum Action { - NONE, - ADD_CONSTRAINT, - UPGRADE_CONSTRAINT, - ADD_PDM_OVERRIDE, - UPGRADE_PDM_OVERRIDE, - ADD_DIRECT_DEPENDENCY + return "Pin a transitive dependency version using the strategy appropriate for the file type " + + "and package manager. For `pyproject.toml`: uv uses `[tool.uv].constraint-dependencies`, " + + "PDM uses `[tool.pdm.overrides]`, and other managers add a direct dependency. " + + "For `requirements.txt` and `Pipfile`: appends the dependency."; } static class Accumulator { final Set projectsToUpdate = new HashSet<>(); - final Map actions = new HashMap<>(); } @Override @@ -95,317 +81,76 @@ public Accumulator getInitialValue(ExecutionContext ctx) { @Override public TreeVisitor getScanner(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("uv.lock")) { - PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( - PyProjectHelper.correspondingPyprojectPath(sourcePath), - document.printAll()); - return document; + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; } - - if (!sourcePath.endsWith("pyproject.toml")) { - return document; + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + if (tree instanceof Toml.Document && sourceFile.getSourcePath().toString().endsWith("uv.lock")) { + PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( + PyProjectHelper.correspondingPyprojectPath(sourceFile.getSourcePath().toString()), + ((Toml.Document) tree).printAll()); + return tree; } - Optional resolution = document.getMarkers() - .findFirst(PythonResolutionResult.class); - if (!resolution.isPresent()) { - return document; + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait == null) { + return tree; } - - PythonResolutionResult marker = resolution.get(); + PythonResolutionResult marker = trait.getMarker(); // Skip if this is a direct dependency if (marker.findDependency(packageName) != null) { - return document; - } - - PythonResolutionResult.PackageManager pm = marker.getPackageManager(); - Action action = null; - - if (pm == PythonResolutionResult.PackageManager.Uv) { - // Uv: require resolved deps, use constraint-dependencies - if (marker.getResolvedDependency(packageName) == null) { - return document; - } - Dependency existing = PyProjectHelper.findDependencyInScope( - marker, packageName, "tool.uv.constraint-dependencies", null); - if (existing == null) { - action = Action.ADD_CONSTRAINT; - } else if (!PyProjectHelper.normalizeVersionConstraint(version).equals(existing.getVersionConstraint())) { - action = Action.UPGRADE_CONSTRAINT; - } - } else if (pm == PythonResolutionResult.PackageManager.Pdm) { - // PDM: use tool.pdm.overrides - Dependency existing = PyProjectHelper.findDependencyInScope( - marker, packageName, "tool.pdm.overrides", null); - if (existing == null) { - action = Action.ADD_PDM_OVERRIDE; - } else if (!PyProjectHelper.normalizeVersionConstraint(version).equals(existing.getVersionConstraint())) { - action = Action.UPGRADE_PDM_OVERRIDE; - } - } else { - // Fallback: add as direct dependency - action = Action.ADD_DIRECT_DEPENDENCY; + return tree; } - if (action == null) { - return document; + // For uv: skip if not in the resolved dependency tree + if (marker.getPackageManager() == PythonResolutionResult.PackageManager.Uv && + marker.getResolvedDependency(packageName) == null) { + return tree; } - acc.actions.put(sourcePath, action); - acc.projectsToUpdate.add(sourcePath); - return document; + acc.projectsToUpdate.add(sourceFile.getSourcePath().toString()); + return tree; } }; } @Override public TreeVisitor getVisitor(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("pyproject.toml") && acc.projectsToUpdate.contains(sourcePath)) { - Action action = acc.actions.get(sourcePath); - if (action == Action.ADD_CONSTRAINT) { - return addToArray(document, ctx, acc, "tool.uv.constraint-dependencies"); - } else if (action == Action.UPGRADE_CONSTRAINT) { - return upgradeConstraint(document, ctx, acc); - } else if (action == Action.ADD_PDM_OVERRIDE) { - return addPdmOverride(document, ctx, acc); - } else if (action == Action.UPGRADE_PDM_OVERRIDE) { - return upgradePdmOverride(document, ctx, acc); - } else if (action == Action.ADD_DIRECT_DEPENDENCY) { - return addToArray(document, ctx, acc, null); + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; + } + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + String sourcePath = sourceFile.getSourcePath().toString(); + + if (acc.projectsToUpdate.contains(sourcePath)) { + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait != null) { + String normalizedName = PythonResolutionResult.normalizeName(packageName); + Map pins = Collections.singletonMap(normalizedName, version); + PythonDependencyFile updated = trait.withPinnedTransitiveDependencies(pins); + if (updated.getTree() != tree) { + return updated.afterModification(ctx); + } } } - if (sourcePath.endsWith("uv.lock")) { - Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock(document, ctx); + if (tree instanceof Toml.Document && sourcePath.endsWith("uv.lock")) { + Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock((Toml.Document) tree, ctx); if (updatedLock != null) { return updatedLock; } } - return document; + return tree; } }; } - - private Toml.Document addToArray(Toml.Document document, ExecutionContext ctx, Accumulator acc, @Nullable String scope) { - String normalizedVersion = PyProjectHelper.normalizeVersionConstraint(version); - String pep508 = packageName + normalizedVersion; - - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { - @Override - public Toml.Array visitArray(Toml.Array array, ExecutionContext ctx) { - Toml.Array a = super.visitArray(array, ctx); - - if (!PyProjectHelper.isInsideDependencyArray(getCursor(), scope, null)) { - return a; - } - - Toml.Literal newLiteral = new Toml.Literal( - randomId(), - Space.EMPTY, - Markers.EMPTY, - TomlType.Primitive.String, - "\"" + pep508 + "\"", - pep508 - ); - - List> existingPadded = a.getPadding().getValues(); - List> newPadded = new ArrayList<>(); - - boolean isEmpty = existingPadded.size() == 1 && - existingPadded.get(0).getElement() instanceof Toml.Empty; - if (existingPadded.isEmpty() || isEmpty) { - newPadded.add(new TomlRightPadded<>(newLiteral, Space.EMPTY, Markers.EMPTY)); - } else { - TomlRightPadded lastPadded = existingPadded.get(existingPadded.size() - 1); - boolean hasTrailingComma = lastPadded.getElement() instanceof Toml.Empty; - - if (hasTrailingComma) { - int lastRealIdx = existingPadded.size() - 2; - Toml lastRealElement = existingPadded.get(lastRealIdx).getElement(); - Toml.Literal formattedLiteral = newLiteral.withPrefix(lastRealElement.getPrefix()); - - for (int i = 0; i <= lastRealIdx; i++) { - newPadded.add(existingPadded.get(i)); - } - newPadded.add(new TomlRightPadded<>(formattedLiteral, Space.EMPTY, Markers.EMPTY)); - newPadded.add(lastPadded); - } else { - Toml lastElement = lastPadded.getElement(); - Space newPrefix = lastElement.getPrefix().getWhitespace().contains("\n") - ? lastElement.getPrefix() - : Space.SINGLE_SPACE; - Toml.Literal formattedLiteral = newLiteral.withPrefix(newPrefix); - - for (int i = 0; i < existingPadded.size() - 1; i++) { - newPadded.add(existingPadded.get(i)); - } - newPadded.add(lastPadded.withAfter(Space.EMPTY)); - newPadded.add(new TomlRightPadded<>(formattedLiteral, lastPadded.getAfter(), Markers.EMPTY)); - } - } - - return a.getPadding().withValues(newPadded); - } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; - } - - private Toml.Document upgradeConstraint(Toml.Document document, ExecutionContext ctx, Accumulator acc) { - String normalizedName = PythonResolutionResult.normalizeName(packageName); - - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { - @Override - public Toml.Literal visitLiteral(Toml.Literal literal, ExecutionContext ctx) { - Toml.Literal l = super.visitLiteral(literal, ctx); - if (l.getType() != TomlType.Primitive.String) { - return l; - } - - Object val = l.getValue(); - if (!(val instanceof String)) { - return l; - } - - // Check if we're inside [tool.uv].constraint-dependencies - if (!isInsideConstraintDependencies()) { - return l; - } - - String spec = (String) val; - String depName = PyProjectHelper.extractPackageName(spec); - if (depName == null || !PythonResolutionResult.normalizeName(depName).equals(normalizedName)) { - return l; - } - - String extras = UpgradeDependencyVersion.extractExtras(spec); - String marker = UpgradeDependencyVersion.extractMarker(spec); - - StringBuilder sb = new StringBuilder(depName); - if (extras != null) { - sb.append('[').append(extras).append(']'); - } - sb.append(PyProjectHelper.normalizeVersionConstraint(version)); - if (marker != null) { - sb.append("; ").append(marker); - } - - String newSpec = sb.toString(); - return l.withSource("\"" + newSpec + "\"").withValue(newSpec); - } - - private boolean isInsideConstraintDependencies() { - Cursor c = getCursor(); - while (c != null) { - if (c.getValue() instanceof Toml.Array) { - return PyProjectHelper.isInsideDependencyArray(c, "tool.uv.constraint-dependencies", null); - } - c = c.getParent(); - } - return false; - } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; - } - - private Toml.Document addPdmOverride(Toml.Document document, ExecutionContext ctx, Accumulator acc) { - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { - @Override - public Toml.Table visitTable(Toml.Table table, ExecutionContext ctx) { - Toml.Table t = super.visitTable(table, ctx); - if (t.getName() == null || !"tool.pdm.overrides".equals(t.getName().getName())) { - return t; - } - - // Build a new KeyValue: packageName = "version" - String normalizedVersion = PyProjectHelper.normalizeVersionConstraint(version); - Toml.Identifier key = new Toml.Identifier( - randomId(), Space.EMPTY, Markers.EMPTY, packageName, packageName); - Toml.Literal value = new Toml.Literal( - randomId(), Space.SINGLE_SPACE, Markers.EMPTY, - TomlType.Primitive.String, "\"" + normalizedVersion + "\"", normalizedVersion); - Toml.KeyValue newKv = new Toml.KeyValue( - randomId(), Space.EMPTY, Markers.EMPTY, - new TomlRightPadded<>(key, Space.SINGLE_SPACE, Markers.EMPTY), - value); - - // Determine prefix for new entry - List values = t.getValues(); - Space entryPrefix; - if (!values.isEmpty()) { - entryPrefix = values.get(values.size() - 1).getPrefix(); - } else { - entryPrefix = Space.format("\n"); - } - newKv = newKv.withPrefix(entryPrefix); - - List newValues = new ArrayList<>(values); - newValues.add(newKv); - return t.withValues(newValues); - } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; - } - - private Toml.Document upgradePdmOverride(Toml.Document document, ExecutionContext ctx, Accumulator acc) { - String normalizedName = PythonResolutionResult.normalizeName(packageName); - - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { - @Override - public Toml.KeyValue visitKeyValue(Toml.KeyValue keyValue, ExecutionContext ctx) { - Toml.KeyValue kv = super.visitKeyValue(keyValue, ctx); - - if (!PyProjectHelper.isInsidePdmOverridesTable(getCursor())) { - return kv; - } - - if (!(kv.getKey() instanceof Toml.Identifier)) { - return kv; - } - String keyName = ((Toml.Identifier) kv.getKey()).getName(); - if (!PythonResolutionResult.normalizeName(keyName).equals(normalizedName)) { - return kv; - } - - if (!(kv.getValue() instanceof Toml.Literal)) { - return kv; - } - - Toml.Literal literal = (Toml.Literal) kv.getValue(); - String normalizedVersion = PyProjectHelper.normalizeVersionConstraint(version); - return kv.withValue(literal.withSource("\"" + normalizedVersion + "\"").withValue(normalizedVersion)); - } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; - } } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PipfileFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PipfileFile.java new file mode 100644 index 0000000000..8255c5507d --- /dev/null +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PipfileFile.java @@ -0,0 +1,355 @@ +/* + * Copyright 2026 the original author or authors. + * + * Moderne Proprietary. Only for use by Moderne customers under the terms of a commercial contract. + */ +package org.openrewrite.python.trait; + +import lombok.Value; +import org.jspecify.annotations.Nullable; +import org.openrewrite.Cursor; +import org.openrewrite.ExecutionContext; +import org.openrewrite.Tree; +import org.openrewrite.marker.Markers; +import org.openrewrite.marker.SearchResult; +import org.openrewrite.python.internal.PyProjectHelper; +import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.toml.TomlIsoVisitor; +import org.openrewrite.toml.tree.Space; +import org.openrewrite.toml.tree.Toml; +import org.openrewrite.toml.tree.TomlRightPadded; +import org.openrewrite.toml.tree.TomlType; +import org.openrewrite.trait.SimpleTraitMatcher; + +import java.util.*; + +/** + * Trait implementation for Pipfile dependency files. + * Pipfile uses key-value tables: {@code [packages]} for production and + * {@code [dev-packages]} for development dependencies. + */ +@Value +public class PipfileFile implements PythonDependencyFile { + + Cursor cursor; + PythonResolutionResult marker; + + @Override + public PipfileFile withUpgradedVersions(Map upgrades, @Nullable String scope, @Nullable String groupName) { + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor>() { + @Override + public Toml.KeyValue visitKeyValue(Toml.KeyValue keyValue, Map u) { + Toml.KeyValue kv = super.visitKeyValue(keyValue, u); + if (!isInsideTargetTable(getCursor(), scope)) { + return kv; + } + if (!(kv.getKey() instanceof Toml.Identifier)) { + return kv; + } + String pkgName = ((Toml.Identifier) kv.getKey()).getName(); + String newVersion = PythonDependencyFile.getByNormalizedName(u, pkgName); + if (newVersion == null) { + return kv; + } + return updateKeyValueVersion(kv, newVersion); + } + }.visitNonNull(doc, upgrades); + if (result != doc) { + PythonResolutionResult updatedMarker = PythonDependencyFile.updateResolvedVersions(marker, upgrades); + result = result.withMarkers(result.getMarkers() + .removeByType(PythonResolutionResult.class) + .addIfAbsent(updatedMarker)); + return new PipfileFile(new Cursor(cursor.getParentOrThrow(), result), updatedMarker); + } + return this; + } + + @Override + public PipfileFile withAddedDependencies(Map additions, @Nullable String scope, @Nullable String groupName) { + String tableName = resolveTableName(scope); + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document original = doc; + for (Map.Entry entry : additions.entrySet()) { + String normalizedName = PythonResolutionResult.normalizeName(entry.getKey()); + if (!hasDependencyInTable(doc, tableName, normalizedName)) { + doc = addToTable(doc, tableName, entry.getKey(), entry.getValue()); + } + } + if (doc != original) { + PythonResolutionResult updatedMarker = PythonDependencyFile.updateResolvedVersions(marker, additions); + doc = doc.withMarkers(doc.getMarkers() + .removeByType(PythonResolutionResult.class) + .addIfAbsent(updatedMarker)); + return new PipfileFile(new Cursor(cursor.getParentOrThrow(), doc), updatedMarker); + } + return this; + } + + @Override + public PipfileFile withRemovedDependencies(Set packageNames, @Nullable String scope, @Nullable String groupName) { + Set normalizedNames = new HashSet<>(); + for (String name : packageNames) { + normalizedNames.add(PythonResolutionResult.normalizeName(name)); + } + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor>() { + @Override + public Toml.Table visitTable(Toml.Table table, Set names) { + Toml.Table t = super.visitTable(table, names); + if (!isTargetTable(t, scope)) { + return t; + } + List newValues = new ArrayList<>(); + boolean changed = false; + for (Toml value : t.getValues()) { + if (value instanceof Toml.KeyValue) { + Toml.KeyValue kv = (Toml.KeyValue) value; + if (kv.getKey() instanceof Toml.Identifier) { + String keyName = ((Toml.Identifier) kv.getKey()).getName(); + if (names.contains(PythonResolutionResult.normalizeName(keyName))) { + changed = true; + continue; + } + } + } + newValues.add(value); + } + return changed ? t.withValues(newValues) : t; + } + }.visitNonNull(doc, normalizedNames); + if (result != doc) { + return new PipfileFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + @Override + public PipfileFile withChangedDependency(String oldPackageName, String newPackageName, @Nullable String newVersion) { + String normalizedOld = PythonResolutionResult.normalizeName(oldPackageName); + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor() { + @Override + public Toml.KeyValue visitKeyValue(Toml.KeyValue keyValue, Integer p) { + Toml.KeyValue kv = super.visitKeyValue(keyValue, p); + if (!(kv.getKey() instanceof Toml.Identifier)) { + return kv; + } + String keyName = ((Toml.Identifier) kv.getKey()).getName(); + if (!PythonResolutionResult.normalizeName(keyName).equals(normalizedOld)) { + return kv; + } + // Check we're inside [packages] or [dev-packages] + if (!isInsideTargetTable(getCursor(), null)) { + return kv; + } + Toml.Identifier newKey = ((Toml.Identifier) kv.getKey()) + .withName(newPackageName) + .withSource(newPackageName); + kv = kv.getPadding().withKey(kv.getPadding().getKey().withElement(newKey)); + if (newVersion != null) { + kv = updateKeyValueVersion(kv, newVersion); + } + return kv; + } + }.visitNonNull(doc, 0); + if (result != doc) { + return new PipfileFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + @Override + public PipfileFile withPinnedTransitiveDependencies(Map pins) { + // Pipfile has no constraint mechanism — add to [packages] + return withAddedDependencies(pins, "packages", null); + } + + @Override + public PipfileFile withDependencySearchMarkers(Map packageMessages, ExecutionContext ctx) { + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor>() { + @Override + public Toml.KeyValue visitKeyValue(Toml.KeyValue keyValue, Map msgs) { + Toml.KeyValue kv = super.visitKeyValue(keyValue, msgs); + if (!isInsideTargetTable(getCursor(), null)) { + return kv; + } + if (!(kv.getKey() instanceof Toml.Identifier)) { + return kv; + } + String pkgName = ((Toml.Identifier) kv.getKey()).getName(); + String message = PythonDependencyFile.getByNormalizedName(msgs, pkgName); + if (message != null) { + return SearchResult.found(kv, message); + } + return kv; + } + }.visitNonNull(doc, packageMessages); + if (result != doc) { + return new PipfileFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + // region Helpers + + /** + * Resolve the target table name from the scope parameter. + * {@code null} defaults to "packages". + */ + private static String resolveTableName(@Nullable String scope) { + if (scope == null || scope.isEmpty() || "packages".equals(scope)) { + return "packages"; + } + return scope; + } + + /** + * Check if a table matches the target scope. + * When scope is null, matches both "packages" and "dev-packages". + */ + private static boolean isTargetTable(Toml.Table table, @Nullable String scope) { + if (table.getName() == null) { + return false; + } + String name = table.getName().getName(); + if (scope == null) { + return "packages".equals(name) || "dev-packages".equals(name); + } + return resolveTableName(scope).equals(name); + } + + /** + * Check if the cursor is inside a target Pipfile table. + * When scope is null, matches both "packages" and "dev-packages". + */ + private static boolean isInsideTargetTable(Cursor cursor, @Nullable String scope) { + Cursor c = cursor; + while (c != null) { + Object val = c.getValue(); + if (val instanceof Toml.Table) { + return isTargetTable((Toml.Table) val, scope); + } + c = c.getParent(); + } + return false; + } + + private static boolean hasDependencyInTable(Toml.Document doc, String tableName, String normalizedName) { + for (Toml value : doc.getValues()) { + if (value instanceof Toml.Table) { + Toml.Table table = (Toml.Table) value; + if (table.getName() != null && tableName.equals(table.getName().getName())) { + for (Toml entry : table.getValues()) { + if (entry instanceof Toml.KeyValue) { + Toml.KeyValue kv = (Toml.KeyValue) entry; + if (kv.getKey() instanceof Toml.Identifier && + PythonResolutionResult.normalizeName( + ((Toml.Identifier) kv.getKey()).getName()).equals(normalizedName)) { + return true; + } + } + } + } + } + } + return false; + } + + private static Toml.Document addToTable(Toml.Document doc, String tableName, String packageName, String version) { + String normalizedVersion = PyProjectHelper.normalizeVersionConstraint(version); + return (Toml.Document) new TomlIsoVisitor() { + @Override + public Toml.Table visitTable(Toml.Table table, Integer p) { + Toml.Table t = super.visitTable(table, p); + if (t.getName() == null || !tableName.equals(t.getName().getName())) { + return t; + } + + Toml.Identifier key = new Toml.Identifier( + Tree.randomId(), Space.EMPTY, Markers.EMPTY, packageName, packageName); + Toml.Literal value = new Toml.Literal( + Tree.randomId(), Space.SINGLE_SPACE, Markers.EMPTY, + TomlType.Primitive.String, "\"" + normalizedVersion + "\"", normalizedVersion); + Toml.KeyValue newKv = new Toml.KeyValue( + Tree.randomId(), Space.EMPTY, Markers.EMPTY, + new TomlRightPadded<>(key, Space.SINGLE_SPACE, Markers.EMPTY), + value); + + List values = t.getValues(); + Space entryPrefix = !values.isEmpty() + ? values.get(values.size() - 1).getPrefix() + : Space.format("\n"); + newKv = newKv.withPrefix(entryPrefix); + + List newValues = new ArrayList<>(values); + newValues.add(newKv); + return t.withValues(newValues); + } + }.visitNonNull(doc, 0); + } + + /** + * Update the version in a key-value pair, handling both simple literals + * ({@code requests = ">=2.28.0"}) and inline tables + * ({@code django = {version = ">=3.2", extras = ["postgres"]}}). + */ + private static Toml.KeyValue updateKeyValueVersion(Toml.KeyValue kv, String newVersion) { + String normalizedVersion = PyProjectHelper.normalizeVersionConstraint(newVersion); + if (kv.getValue() instanceof Toml.Literal) { + Toml.Literal literal = (Toml.Literal) kv.getValue(); + if (normalizedVersion.equals(literal.getValue())) { + return kv; + } + return kv.withValue(literal.withSource("\"" + normalizedVersion + "\"").withValue(normalizedVersion)); + } + if (kv.getValue() instanceof Toml.Table) { + // Inline table: update the "version" key inside + Toml.Table inlineTable = (Toml.Table) kv.getValue(); + List newValues = new ArrayList<>(); + boolean changed = false; + for (Toml inner : inlineTable.getValues()) { + if (inner instanceof Toml.KeyValue) { + Toml.KeyValue innerKv = (Toml.KeyValue) inner; + if (innerKv.getKey() instanceof Toml.Identifier && + "version".equals(((Toml.Identifier) innerKv.getKey()).getName()) && + innerKv.getValue() instanceof Toml.Literal) { + Toml.Literal literal = (Toml.Literal) innerKv.getValue(); + if (!normalizedVersion.equals(literal.getValue())) { + newValues.add(innerKv.withValue( + literal.withSource("\"" + normalizedVersion + "\"").withValue(normalizedVersion))); + changed = true; + continue; + } + } + } + newValues.add(inner); + } + if (changed) { + return kv.withValue(inlineTable.withValues(newValues)); + } + } + return kv; + } + + // endregion + + public static class Matcher extends SimpleTraitMatcher { + @Override + protected @Nullable PipfileFile test(Cursor cursor) { + Object value = cursor.getValue(); + if (value instanceof Toml.Document) { + Toml.Document doc = (Toml.Document) value; + if ("Pipfile".equals(doc.getSourcePath().getFileName().toString())) { + PythonResolutionResult marker = doc.getMarkers() + .findFirst(PythonResolutionResult.class).orElse(null); + if (marker != null) { + return new PipfileFile(cursor, marker); + } + } + } + return null; + } + } +} diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java new file mode 100644 index 0000000000..ec4ec21ad0 --- /dev/null +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java @@ -0,0 +1,445 @@ +/* + * Copyright 2026 the original author or authors. + * + * Moderne Proprietary. Only for use by Moderne customers under the terms of a commercial contract. + */ +package org.openrewrite.python.trait; + +import lombok.Value; +import org.jspecify.annotations.Nullable; +import org.openrewrite.Cursor; +import org.openrewrite.ExecutionContext; +import org.openrewrite.SourceFile; +import org.openrewrite.Tree; +import org.openrewrite.marker.Markers; +import org.openrewrite.marker.SearchResult; +import org.openrewrite.python.internal.PyProjectHelper; +import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.toml.TomlIsoVisitor; +import org.openrewrite.toml.tree.Space; +import org.openrewrite.toml.tree.Toml; +import org.openrewrite.toml.tree.TomlRightPadded; +import org.openrewrite.toml.tree.TomlType; +import org.openrewrite.trait.SimpleTraitMatcher; + +import java.util.*; + +@Value +public class PyProjectFile implements PythonDependencyFile { + + Cursor cursor; + PythonResolutionResult marker; + + @Override + public PyProjectFile withUpgradedVersions(Map upgrades, @Nullable String scope, @Nullable String groupName) { + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor>() { + @Override + public Toml.Literal visitLiteral(Toml.Literal literal, Map u) { + if (!isInsideTargetArray(getCursor(), scope, groupName)) { + return literal; + } + + String spec = literal.getValue().toString(); + String packageName = PyProjectHelper.extractPackageName(spec); + if (packageName == null) { + return literal; + } + + String fixVersion = PythonDependencyFile.getByNormalizedName(u, packageName); + if (fixVersion == null) { + return literal; + } + + String newSpec = PythonDependencyFile.rewritePep508Spec(spec, packageName, fixVersion); + if (!newSpec.equals(spec)) { + return literal.withSource("\"" + newSpec + "\"").withValue(newSpec); + } + return literal; + } + }.visitNonNull(doc, upgrades); + if (result != doc) { + PythonResolutionResult updatedMarker = PythonDependencyFile.updateResolvedVersions(marker, upgrades); + result = result.withMarkers(result.getMarkers() + .removeByType(PythonResolutionResult.class) + .addIfAbsent(updatedMarker)); + return new PyProjectFile(new Cursor(cursor.getParentOrThrow(), result), updatedMarker); + } + return this; + } + + @Override + public PyProjectFile withAddedDependencies(Map additions, @Nullable String scope, @Nullable String groupName) { + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document original = doc; + for (Map.Entry entry : additions.entrySet()) { + if (PyProjectHelper.findDependencyInScope(marker, entry.getKey(), scope, groupName) == null) { + String pep508 = entry.getKey() + PyProjectHelper.normalizeVersionConstraint(entry.getValue()); + doc = addDependencyToArray(doc, pep508, scope, groupName); + } + } + if (doc != original) { + PythonResolutionResult updatedMarker = PythonDependencyFile.updateResolvedVersions(marker, additions); + doc = doc.withMarkers(doc.getMarkers() + .removeByType(PythonResolutionResult.class) + .addIfAbsent(updatedMarker)); + return new PyProjectFile(new Cursor(cursor.getParentOrThrow(), doc), updatedMarker); + } + return this; + } + + @Override + public PyProjectFile withRemovedDependencies(Set packageNames, @Nullable String scope, @Nullable String groupName) { + Set normalizedNames = new HashSet<>(); + for (String name : packageNames) { + normalizedNames.add(PythonResolutionResult.normalizeName(name)); + } + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor>() { + @Override + public Toml.Array visitArray(Toml.Array array, Set names) { + Toml.Array a = super.visitArray(array, names); + if (!PyProjectHelper.isInsideDependencyArray(getCursor(), scope, groupName)) { + return a; + } + + List> existingPadded = a.getPadding().getValues(); + List> newPadded = new ArrayList<>(); + boolean found = false; + int removedIdx = -1; + + for (int i = 0; i < existingPadded.size(); i++) { + TomlRightPadded padded = existingPadded.get(i); + Toml element = padded.getElement(); + + if (element instanceof Toml.Literal) { + Object val = ((Toml.Literal) element).getValue(); + if (val instanceof String) { + String depName = PyProjectHelper.extractPackageName((String) val); + if (depName != null && names.contains(PythonResolutionResult.normalizeName(depName))) { + if (!found) { + removedIdx = i; + } + found = true; + continue; + } + } + } + newPadded.add(padded); + } + + if (!found) { + return a; + } + + if (removedIdx == 0 && !newPadded.isEmpty()) { + TomlRightPadded first = newPadded.get(0); + if (!(first.getElement() instanceof Toml.Empty)) { + Space originalPrefix = existingPadded.get(removedIdx).getElement().getPrefix(); + newPadded.set(0, first.map(el -> el.withPrefix(originalPrefix))); + } + } + + return a.getPadding().withValues(newPadded); + } + }.visitNonNull(doc, normalizedNames); + if (result != doc) { + return new PyProjectFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + @Override + public PyProjectFile withChangedDependency(String oldPackageName, String newPackageName, @Nullable String newVersion) { + String normalizedOld = PythonResolutionResult.normalizeName(oldPackageName); + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor() { + @Override + public Toml.Literal visitLiteral(Toml.Literal literal, Integer p) { + if (literal.getType() != TomlType.Primitive.String) { + return literal; + } + Object val = literal.getValue(); + if (!(val instanceof String)) { + return literal; + } + String spec = (String) val; + String depName = PyProjectHelper.extractPackageName(spec); + if (depName == null || !PythonResolutionResult.normalizeName(depName).equals(normalizedOld)) { + return literal; + } + + String newSpec = buildChangedSpec(spec, depName, newPackageName, newVersion); + return literal.withSource("\"" + newSpec + "\"").withValue(newSpec); + } + }.visitNonNull(doc, 0); + if (result != doc) { + return new PyProjectFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + private static String buildChangedSpec(String oldSpec, String oldName, String newName, @Nullable String newVer) { + StringBuilder sb = new StringBuilder(newName); + + // Preserve extras + int extrasStart = oldSpec.indexOf('[', oldName.length()); + int extrasEnd = oldSpec.indexOf(']', oldName.length()); + if (extrasStart >= 0 && extrasEnd > extrasStart) { + sb.append(oldSpec, extrasStart, extrasEnd + 1); + } + + if (newVer != null) { + sb.append(PyProjectHelper.normalizeVersionConstraint(newVer)); + } else { + // Preserve original version constraint + String remaining = oldSpec.substring(oldName.length()).trim(); + if (remaining.startsWith("[")) { + int end = remaining.indexOf(']'); + if (end >= 0) { + remaining = remaining.substring(end + 1).trim(); + } + } + int markerIdx = remaining.indexOf(';'); + String versionPart = markerIdx >= 0 ? remaining.substring(0, markerIdx).trim() : remaining.trim(); + if (!versionPart.isEmpty()) { + sb.append(versionPart); + } + } + + // Preserve environment markers + int semiIdx = oldSpec.indexOf(';'); + if (semiIdx >= 0) { + sb.append("; ").append(oldSpec.substring(semiIdx + 1).trim()); + } + + return sb.toString(); + } + + @Override + public PyProjectFile withPinnedTransitiveDependencies(Map pins) { + PythonResolutionResult.PackageManager pm = marker.getPackageManager(); + if (pm == PythonResolutionResult.PackageManager.Uv) { + return pinViaArrayScope(pins, "tool.uv.constraint-dependencies"); + } else if (pm == PythonResolutionResult.PackageManager.Pdm) { + return pinViaPdmOverrides(pins); + } else { + // Fallback: add as direct dependency + return withAddedDependencies(pins, null, null); + } + } + + private PyProjectFile pinViaArrayScope(Map pins, String scope) { + PyProjectFile current = this; + for (Map.Entry entry : pins.entrySet()) { + PythonResolutionResult.Dependency existing = PyProjectHelper.findDependencyInScope( + current.marker, entry.getKey(), scope, null); + if (existing == null) { + current = current.withAddedDependencies( + Collections.singletonMap(entry.getKey(), entry.getValue()), scope, null); + } else { + current = current.withUpgradedVersions( + Collections.singletonMap(entry.getKey(), entry.getValue()), scope, null); + } + } + return current; + } + + private PyProjectFile pinViaPdmOverrides(Map pins) { + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document updated = doc; + for (Map.Entry entry : pins.entrySet()) { + PythonResolutionResult.Dependency existing = PyProjectHelper.findDependencyInScope( + marker, entry.getKey(), "tool.pdm.overrides", null); + if (existing == null) { + updated = addPdmOverride(updated, entry.getKey(), entry.getValue()); + } else { + updated = upgradePdmOverride(updated, entry.getKey(), entry.getValue()); + } + } + if (updated != doc) { + return new PyProjectFile(new Cursor(cursor.getParentOrThrow(), updated), marker); + } + return this; + } + + private static Toml.Document addPdmOverride(Toml.Document doc, String packageName, String version) { + String normalizedVersion = PyProjectHelper.normalizeVersionConstraint(version); + return (Toml.Document) new TomlIsoVisitor() { + @Override + public Toml.Table visitTable(Toml.Table table, Integer p) { + Toml.Table t = super.visitTable(table, p); + if (t.getName() == null || !"tool.pdm.overrides".equals(t.getName().getName())) { + return t; + } + + Toml.Identifier key = new Toml.Identifier( + Tree.randomId(), Space.EMPTY, Markers.EMPTY, packageName, packageName); + Toml.Literal value = new Toml.Literal( + Tree.randomId(), Space.SINGLE_SPACE, Markers.EMPTY, + TomlType.Primitive.String, "\"" + normalizedVersion + "\"", normalizedVersion); + Toml.KeyValue newKv = new Toml.KeyValue( + Tree.randomId(), Space.EMPTY, Markers.EMPTY, + new TomlRightPadded<>(key, Space.SINGLE_SPACE, Markers.EMPTY), + value); + + List values = t.getValues(); + Space entryPrefix = !values.isEmpty() + ? values.get(values.size() - 1).getPrefix() + : Space.format("\n"); + newKv = newKv.withPrefix(entryPrefix); + + List newValues = new ArrayList<>(values); + newValues.add(newKv); + return t.withValues(newValues); + } + }.visitNonNull(doc, 0); + } + + private static Toml.Document upgradePdmOverride(Toml.Document doc, String packageName, String version) { + String normalizedName = PythonResolutionResult.normalizeName(packageName); + String normalizedVersion = PyProjectHelper.normalizeVersionConstraint(version); + return (Toml.Document) new TomlIsoVisitor() { + @Override + public Toml.KeyValue visitKeyValue(Toml.KeyValue keyValue, Integer p) { + Toml.KeyValue kv = super.visitKeyValue(keyValue, p); + if (!PyProjectHelper.isInsidePdmOverridesTable(getCursor())) { + return kv; + } + if (!(kv.getKey() instanceof Toml.Identifier)) { + return kv; + } + String keyName = ((Toml.Identifier) kv.getKey()).getName(); + if (!PythonResolutionResult.normalizeName(keyName).equals(normalizedName)) { + return kv; + } + if (!(kv.getValue() instanceof Toml.Literal)) { + return kv; + } + Toml.Literal literal = (Toml.Literal) kv.getValue(); + if (normalizedVersion.equals(literal.getValue())) { + return kv; + } + return kv.withValue(literal.withSource("\"" + normalizedVersion + "\"").withValue(normalizedVersion)); + } + }.visitNonNull(doc, 0); + } + + private static Toml.Document addDependencyToArray(Toml.Document d, String pep508, + @Nullable String scope, @Nullable String groupName) { + return (Toml.Document) new TomlIsoVisitor() { + @Override + public Toml.Array visitArray(Toml.Array array, Integer p) { + Toml.Array a = super.visitArray(array, p); + if (!PyProjectHelper.isInsideDependencyArray(getCursor(), scope, groupName)) { + return a; + } + + Toml.Literal newLiteral = new Toml.Literal( + Tree.randomId(), Space.EMPTY, Markers.EMPTY, + TomlType.Primitive.String, "\"" + pep508 + "\"", pep508); + + List> existingPadded = a.getPadding().getValues(); + List> newPadded = new ArrayList<>(); + + boolean isEmpty = existingPadded.size() == 1 && + existingPadded.get(0).getElement() instanceof Toml.Empty; + if (existingPadded.isEmpty() || isEmpty) { + newPadded.add(new TomlRightPadded<>(newLiteral, Space.EMPTY, Markers.EMPTY)); + } else { + TomlRightPadded lastPadded = existingPadded.get(existingPadded.size() - 1); + boolean hasTrailingComma = lastPadded.getElement() instanceof Toml.Empty; + + if (hasTrailingComma) { + int lastRealIdx = existingPadded.size() - 2; + Toml lastRealElement = existingPadded.get(lastRealIdx).getElement(); + Toml.Literal formatted = newLiteral.withPrefix(lastRealElement.getPrefix()); + for (int i = 0; i <= lastRealIdx; i++) { + newPadded.add(existingPadded.get(i)); + } + newPadded.add(new TomlRightPadded<>(formatted, Space.EMPTY, Markers.EMPTY)); + newPadded.add(lastPadded); + } else { + Toml lastElement = lastPadded.getElement(); + Space newPrefix = lastElement.getPrefix().getWhitespace().contains("\n") ? + lastElement.getPrefix() : + Space.SINGLE_SPACE; + Toml.Literal formatted = newLiteral.withPrefix(newPrefix); + for (int i = 0; i < existingPadded.size() - 1; i++) { + newPadded.add(existingPadded.get(i)); + } + newPadded.add(lastPadded.withAfter(Space.EMPTY)); + newPadded.add(new TomlRightPadded<>(formatted, lastPadded.getAfter(), Markers.EMPTY)); + } + } + + return a.getPadding().withValues(newPadded); + } + }.visitNonNull(d, 0); + } + + @Override + public PyProjectFile withDependencySearchMarkers(Map packageMessages, ExecutionContext ctx) { + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor>() { + @Override + public Toml.Literal visitLiteral(Toml.Literal literal, Map msgs) { + if (!isInsideTargetArray(getCursor(), null, null)) { + return literal; + } + + String spec = literal.getValue().toString(); + String packageName = PyProjectHelper.extractPackageName(spec); + if (packageName == null) { + return literal; + } + + String normalizedName = PythonResolutionResult.normalizeName(packageName); + String message = msgs.get(normalizedName); + if (message != null) { + return SearchResult.found(literal, message); + } + return literal; + } + }.visitNonNull(doc, packageMessages); + if (result != doc) { + return new PyProjectFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + @Override + public SourceFile afterModification(ExecutionContext ctx) { + // regenerateLockAndRefreshMarker already guards against missing lock content internally + return PyProjectHelper.regenerateLockAndRefreshMarker((Toml.Document) getTree(), ctx); + } + + private static boolean isInsideTargetArray(Cursor cursor, @Nullable String scope, @Nullable String groupName) { + Cursor c = cursor; + while (c != null) { + if (c.getValue() instanceof Toml.Array) { + return PyProjectHelper.isInsideDependencyArray(c, scope, groupName); + } + c = c.getParent(); + } + return false; + } + + public static class Matcher extends SimpleTraitMatcher { + @Override + protected @Nullable PyProjectFile test(Cursor cursor) { + Object value = cursor.getValue(); + if (value instanceof Toml.Document) { + Toml.Document doc = (Toml.Document) value; + if (doc.getSourcePath().toString().endsWith("pyproject.toml")) { + PythonResolutionResult marker = doc.getMarkers() + .findFirst(PythonResolutionResult.class).orElse(null); + if (marker != null) { + return new PyProjectFile(cursor, marker); + } + } + } + return null; + } + } +} diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java new file mode 100644 index 0000000000..4365661e2e --- /dev/null +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java @@ -0,0 +1,182 @@ +/* + * Copyright 2026 the original author or authors. + * + * Moderne Proprietary. Only for use by Moderne customers under the terms of a commercial contract. + */ +package org.openrewrite.python.trait; + +import org.jspecify.annotations.Nullable; +import org.openrewrite.Cursor; +import org.openrewrite.ExecutionContext; +import org.openrewrite.SourceFile; +import org.openrewrite.python.internal.PyProjectHelper; +import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.trait.SimpleTraitMatcher; +import org.openrewrite.trait.Trait; + +import java.util.*; + +/** + * Trait for Python dependency files (pyproject.toml, requirements.txt, etc.). + * Use {@link org.openrewrite.python.internal.PyProjectHelper#extractPackageName(String)} + * for PEP 508 package name extraction. + */ + +public interface PythonDependencyFile extends Trait { + + PythonResolutionResult getMarker(); + + /** + * Upgrade version constraints for dependencies in the specified scope. + * + * @param upgrades package name → new version + * @param scope the TOML scope, or {@code null} for the default ({@code [project].dependencies}) + * @param groupName required for {@code "project.optional-dependencies"} or {@code "dependency-groups"} + */ + PythonDependencyFile withUpgradedVersions(Map upgrades, @Nullable String scope, @Nullable String groupName); + + /** + * Add dependencies to the specified scope. + * + * @param additions package name → version constraint (e.g. {@code "2.0"} or {@code ">=2.0"}) + * @param scope the TOML scope (e.g. {@code "project.optional-dependencies"}, + * {@code "dependency-groups"}), or {@code null} for the default + * ({@code [project].dependencies}) + * @param groupName required when scope is {@code "project.optional-dependencies"} + * or {@code "dependency-groups"}, otherwise {@code null} + */ + PythonDependencyFile withAddedDependencies(Map additions, @Nullable String scope, @Nullable String groupName); + + /** + * Pin transitive dependencies using the strategy appropriate for this file's + * package manager. For pyproject.toml: uv uses {@code [tool.uv].constraint-dependencies}, + * PDM uses {@code [tool.pdm.overrides]}, and other managers add a direct dependency. + * For requirements.txt: appends the dependency. + * + * @param pins package name → version constraint + */ + PythonDependencyFile withPinnedTransitiveDependencies(Map pins); + + /** + * Remove dependencies from the specified scope. + * + * @param packageNames package names to remove + * @param scope the TOML scope, or {@code null} for the default ({@code [project].dependencies}) + * @param groupName required for {@code "project.optional-dependencies"} or {@code "dependency-groups"} + */ + PythonDependencyFile withRemovedDependencies(Set packageNames, @Nullable String scope, @Nullable String groupName); + + /** + * Change a dependency to a different package, searching all scopes. + * + * @param oldPackageName the current package name + * @param newPackageName the new package name + * @param newVersion optional new version constraint, or {@code null} to preserve the original + */ + PythonDependencyFile withChangedDependency(String oldPackageName, String newPackageName, @Nullable String newVersion); + + /** + * Add search result markers for vulnerable dependencies. + * + * @param packageMessages package name → vulnerability description message + */ + PythonDependencyFile withDependencySearchMarkers(Map packageMessages, ExecutionContext ctx); + + /** + * Post-process the modified source file, e.g. regenerate lock files. + * Called by recipes after a trait method modifies the tree. + * The default implementation returns the tree unchanged. + * + * @param ctx the execution context + * @return the post-processed source file + */ + default SourceFile afterModification(ExecutionContext ctx) { + return getTree(); + } + + /** + * Rewrite a PEP 508 dependency spec with a new version constraint. + * Preserves extras and environment markers. The version is normalized + * via {@link PyProjectHelper#normalizeVersionConstraint(String)}, + * so both {@code "2.31.0"} and {@code ">=2.31.0"} are accepted. + */ + static String rewritePep508Spec(String spec, String packageName, String newVersion) { + int nameEnd = packageName.length(); + StringBuilder sb = new StringBuilder(packageName); + + // Preserve extras like [security] + if (nameEnd < spec.length() && spec.charAt(nameEnd) == '[') { + int extrasEnd = spec.indexOf(']', nameEnd); + if (extrasEnd >= 0) { + extrasEnd++; + sb.append(spec, nameEnd, extrasEnd); + nameEnd = extrasEnd; + } + } + + sb.append(PyProjectHelper.normalizeVersionConstraint(newVersion)); + + // Preserve environment markers (everything after ';') + int semiIdx = spec.indexOf(';', nameEnd); + if (semiIdx >= 0) { + sb.append(spec.substring(semiIdx)); + } + + return sb.toString(); + } + + /** + * Look up a value in a map by normalizing the lookup key per PEP 503. + * This allows callers to pass non-normalized package names. + */ + static @Nullable String getByNormalizedName(Map map, String name) { + String normalized = PythonResolutionResult.normalizeName(name); + for (Map.Entry entry : map.entrySet()) { + if (PythonResolutionResult.normalizeName(entry.getKey()).equals(normalized)) { + return entry.getValue(); + } + } + return null; + } + + /** + * Update the resolved dependency versions in a marker to reflect version changes. + * Returns the same marker if no changes were needed. + */ + static PythonResolutionResult updateResolvedVersions( + PythonResolutionResult marker, Map versionUpdates) { + List resolved = marker.getResolvedDependencies(); + List updated = new ArrayList<>(resolved.size()); + boolean changed = false; + for (PythonResolutionResult.ResolvedDependency dep : resolved) { + String normalizedName = PythonResolutionResult.normalizeName(dep.getName()); + String newVersion = versionUpdates.get(normalizedName); + if (newVersion != null && !newVersion.equals(dep.getVersion())) { + updated.add(dep.withVersion(newVersion)); + changed = true; + } else { + updated.add(dep); + } + } + return changed ? marker.withResolvedDependencies(updated) : marker; + } + + class Matcher extends SimpleTraitMatcher { + private final RequirementsFile.Matcher reqMatcher = new RequirementsFile.Matcher(); + private final PyProjectFile.Matcher pyprojectMatcher = new PyProjectFile.Matcher(); + private final PipfileFile.Matcher pipfileMatcher = new PipfileFile.Matcher(); + + @Override + protected @Nullable PythonDependencyFile test(Cursor cursor) { + PythonDependencyFile r = reqMatcher.test(cursor); + if (r != null) { + return r; + } + r = pyprojectMatcher.test(cursor); + if (r != null) { + return r; + } + return pipfileMatcher.test(cursor); + } + } +} diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java new file mode 100644 index 0000000000..c9644c1a9f --- /dev/null +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java @@ -0,0 +1,241 @@ +/* + * Copyright 2026 the original author or authors. + * + * Moderne Proprietary. Only for use by Moderne customers under the terms of a commercial contract. + */ +package org.openrewrite.python.trait; + +import lombok.Value; +import org.jspecify.annotations.Nullable; +import org.openrewrite.Cursor; +import org.openrewrite.ExecutionContext; +import org.openrewrite.python.RequirementsTxtParser; +import org.openrewrite.python.internal.PyProjectHelper; +import org.openrewrite.text.Find; +import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.text.PlainText; +import org.openrewrite.trait.SimpleTraitMatcher; + +import java.util.*; +import java.util.regex.Pattern; + +@Value +public class RequirementsFile implements PythonDependencyFile { + private static final RequirementsTxtParser PARSER = new RequirementsTxtParser(); + private static final Pattern SCOPE_PATTERN = Pattern.compile("requirements(?:-([\\w-]+))?\\.(?:txt|in)"); + + Cursor cursor; + PythonResolutionResult marker; + + /** + * Check whether this file matches the given scope. + *
    + *
  • {@code null} → matches all requirements files
  • + *
  • {@code ""} (empty) → matches only {@code requirements.txt} / {@code requirements.in}
  • + *
  • {@code "dev"} → matches only {@code requirements-dev.txt} / {@code requirements-dev.in}
  • + *
+ */ + private boolean matchesScope(@Nullable String scope) { + if (scope == null) { + return true; + } + String filename = getTree().getSourcePath().getFileName().toString(); + java.util.regex.Matcher m = SCOPE_PATTERN.matcher(filename); + if (!m.matches()) { + return false; + } + String fileSuffix = m.group(1); // null for requirements.txt, "dev" for requirements-dev.txt + if (scope.isEmpty()) { + return fileSuffix == null; + } + return scope.equals(fileSuffix); + } + + @Override + public RequirementsFile withUpgradedVersions(Map upgrades, @Nullable String scope, @Nullable String groupName) { + if (!matchesScope(scope)) { + return this; + } + PlainText pt = (PlainText) getTree(); + String text = pt.getText(); + String[] lines = text.split("\n", -1); + boolean changed = false; + + for (int i = 0; i < lines.length; i++) { + String line = lines[i]; + String trimmed = line.trim(); + if (trimmed.isEmpty() || trimmed.startsWith("#") || trimmed.startsWith("-")) { + continue; + } + + String packageName = PyProjectHelper.extractPackageName(trimmed); + if (packageName == null) { + continue; + } + + String fixVersion = PythonDependencyFile.getByNormalizedName(upgrades, packageName); + if (fixVersion == null) { + continue; + } + + String newSpec = PythonDependencyFile.rewritePep508Spec(trimmed, packageName, fixVersion); + if (!newSpec.equals(trimmed)) { + // Preserve leading whitespace from the original line + int leadingWs = 0; + while (leadingWs < line.length() && Character.isWhitespace(line.charAt(leadingWs))) { + leadingWs++; + } + lines[i] = line.substring(0, leadingWs) + newSpec; + changed = true; + } + } + + if (changed) { + PythonResolutionResult updatedMarker = PythonDependencyFile.updateResolvedVersions(marker, upgrades); + PlainText newPt = pt.withText(String.join("\n", lines)); + newPt = newPt.withMarkers(newPt.getMarkers() + .removeByType(PythonResolutionResult.class) + .addIfAbsent(updatedMarker)); + return new RequirementsFile(new Cursor(cursor.getParentOrThrow(), newPt), updatedMarker); + } + return this; + } + + @Override + public RequirementsFile withAddedDependencies(Map additions, @Nullable String scope, @Nullable String groupName) { + if (!matchesScope(scope)) { + return this; + } + PlainText pt = (PlainText) getTree(); + String text = pt.getText(); + String[] lines = text.split("\n", -1); + + Set existingPackages = new HashSet<>(); + for (String line : lines) { + String pkg = PyProjectHelper.extractPackageName(line.trim()); + if (pkg != null) { + existingPackages.add(PythonResolutionResult.normalizeName(pkg)); + } + } + + StringBuilder sb = new StringBuilder(text); + boolean changed = false; + for (Map.Entry entry : additions.entrySet()) { + if (!existingPackages.contains(PythonResolutionResult.normalizeName(entry.getKey()))) { + sb.append("\n").append(entry.getKey()).append(PyProjectHelper.normalizeVersionConstraint(entry.getValue())); + changed = true; + } + } + + if (changed) { + PythonResolutionResult updatedMarker = PythonDependencyFile.updateResolvedVersions(marker, additions); + PlainText newPt = pt.withText(sb.toString()); + newPt = newPt.withMarkers(newPt.getMarkers() + .removeByType(PythonResolutionResult.class) + .addIfAbsent(updatedMarker)); + return new RequirementsFile(new Cursor(cursor.getParentOrThrow(), newPt), updatedMarker); + } + return this; + } + + @Override + public RequirementsFile withRemovedDependencies(Set packageNames, @Nullable String scope, @Nullable String groupName) { + if (!matchesScope(scope)) { + return this; + } + Set normalizedNames = new HashSet<>(); + for (String name : packageNames) { + normalizedNames.add(PythonResolutionResult.normalizeName(name)); + } + PlainText pt = (PlainText) getTree(); + String[] lines = pt.getText().split("\n", -1); + List kept = new ArrayList<>(); + boolean changed = false; + + for (String line : lines) { + String pkg = PyProjectHelper.extractPackageName(line.trim()); + if (pkg != null && normalizedNames.contains(PythonResolutionResult.normalizeName(pkg))) { + changed = true; + } else { + kept.add(line); + } + } + + if (changed) { + PlainText newPt = pt.withText(String.join("\n", kept)); + return new RequirementsFile(new Cursor(cursor.getParentOrThrow(), newPt), marker); + } + return this; + } + + @Override + public RequirementsFile withChangedDependency(String oldPackageName, String newPackageName, @Nullable String newVersion) { + PlainText pt = (PlainText) getTree(); + String[] lines = pt.getText().split("\n", -1); + boolean changed = false; + String normalizedOld = PythonResolutionResult.normalizeName(oldPackageName); + + for (int i = 0; i < lines.length; i++) { + String trimmed = lines[i].trim(); + String pkg = PyProjectHelper.extractPackageName(trimmed); + if (pkg != null && PythonResolutionResult.normalizeName(pkg).equals(normalizedOld)) { + // Preserve leading whitespace + int leadingWs = 0; + while (leadingWs < lines[i].length() && Character.isWhitespace(lines[i].charAt(leadingWs))) { + leadingWs++; + } + String newSpec; + if (newVersion != null) { + newSpec = newPackageName + PyProjectHelper.normalizeVersionConstraint(newVersion); + } else { + // Replace just the name, keep the rest + newSpec = newPackageName + trimmed.substring(pkg.length()); + } + lines[i] = lines[i].substring(0, leadingWs) + newSpec; + changed = true; + } + } + + if (changed) { + PlainText newPt = pt.withText(String.join("\n", lines)); + return new RequirementsFile(new Cursor(cursor.getParentOrThrow(), newPt), marker); + } + return this; + } + + @Override + public RequirementsFile withPinnedTransitiveDependencies(Map pins) { + return withAddedDependencies(pins, null, null); + } + + @Override + public RequirementsFile withDependencySearchMarkers(Map packageMessages, ExecutionContext ctx) { + PlainText result = (PlainText) getTree(); + for (Map.Entry entry : packageMessages.entrySet()) { + Find find = new Find(entry.getKey(), null, false, null, null, null, null, null); + result = (PlainText) find.getVisitor().visitNonNull(result, ctx); + } + if (result != getTree()) { + return new RequirementsFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + public static class Matcher extends SimpleTraitMatcher { + @Override + protected @Nullable RequirementsFile test(Cursor cursor) { + Object value = cursor.getValue(); + if (value instanceof PlainText) { + PlainText pt = (PlainText) value; + if (PARSER.accept(pt.getSourcePath())) { + PythonResolutionResult marker = pt.getMarkers() + .findFirst(PythonResolutionResult.class).orElse(null); + if (marker != null) { + return new RequirementsFile(cursor, marker); + } + } + } + return null; + } + } +} diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/package-info.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/package-info.java new file mode 100644 index 0000000000..b7bf061317 --- /dev/null +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2024 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +@NullMarked +@NonNullFields +package org.openrewrite.python.trait; + +import org.jspecify.annotations.NullMarked; +import org.openrewrite.internal.lang.NonNullFields; diff --git a/rewrite-python/src/test/java/org/openrewrite/python/AddDependencyTest.java b/rewrite-python/src/test/java/org/openrewrite/python/AddDependencyTest.java index 166746d851..081f92c248 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/AddDependencyTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/AddDependencyTest.java @@ -341,6 +341,43 @@ void addDependencyWithBareVersion() { ); } + @Test + void addDependencyToRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new AddDependency("flask", ">=2.0", null, null)), + requirementsTxt( + "requests>=2.28.0", + "requests>=2.28.0\nflask>=2.0" + ) + ); + } + + @Test + void skipWhenAlreadyPresentInRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new AddDependency("requests", null, null, null)), + requirementsTxt("requests>=2.28.0") + ); + } + + @Test + void addDependencyToPipfile() { + rewriteRun( + spec -> spec.recipe(new AddDependency("flask", ">=2.0", null, null)), + pipfile( + """ + [packages] + requests = ">=2.28.0" + """, + """ + [packages] + requests = ">=2.28.0" + flask = ">=2.0" + """ + ) + ); + } + @Test void validateRequiresGroupName() { var recipe = new AddDependency("pytest", null, "project.optional-dependencies", null); diff --git a/rewrite-python/src/test/java/org/openrewrite/python/ChangeDependencyTest.java b/rewrite-python/src/test/java/org/openrewrite/python/ChangeDependencyTest.java index fd142c38d9..1d000fd899 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/ChangeDependencyTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/ChangeDependencyTest.java @@ -18,7 +18,7 @@ import org.junit.jupiter.api.Test; import org.openrewrite.test.RewriteTest; -import static org.openrewrite.python.Assertions.pyproject; +import static org.openrewrite.python.Assertions.*; class ChangeDependencyTest implements RewriteTest { @@ -171,4 +171,53 @@ void renameAcrossScopes() { ) ); } + + @Test + void renamePackageInRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new ChangeDependency("requests", "httpx", null)), + requirementsTxt( + "requests>=2.28.0\nclick>=8.0", + "httpx>=2.28.0\nclick>=8.0" + ) + ); + } + + @Test + void renameWithNewVersionInRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new ChangeDependency("requests", "httpx", ">=0.24.0")), + requirementsTxt( + "requests>=2.28.0\nclick>=8.0", + "httpx>=0.24.0\nclick>=8.0" + ) + ); + } + + @Test + void skipWhenNotFoundInRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new ChangeDependency("flask", "quart", null)), + requirementsTxt("requests>=2.28.0") + ); + } + + @Test + void renamePackageInPipfile() { + rewriteRun( + spec -> spec.recipe(new ChangeDependency("requests", "httpx", ">=0.24.0")), + pipfile( + """ + [packages] + requests = ">=2.28.0" + click = ">=8.0" + """, + """ + [packages] + httpx = ">=0.24.0" + click = ">=8.0" + """ + ) + ); + } } diff --git a/rewrite-python/src/test/java/org/openrewrite/python/RemoveDependencyTest.java b/rewrite-python/src/test/java/org/openrewrite/python/RemoveDependencyTest.java index 9bfbcdc7ba..25ad1e1534 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/RemoveDependencyTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/RemoveDependencyTest.java @@ -21,8 +21,7 @@ import java.nio.file.Path; -import static org.openrewrite.python.Assertions.pyproject; -import static org.openrewrite.python.Assertions.uv; +import static org.openrewrite.python.Assertions.*; class RemoveDependencyTest implements RewriteTest { @@ -261,4 +260,41 @@ void removeFromDependencyGroup() { ) ); } + + @Test + void removeDependencyFromRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new RemoveDependency("click", null, null)), + requirementsTxt( + "requests>=2.28.0\nclick>=8.0", + "requests>=2.28.0" + ) + ); + } + + @Test + void skipWhenNotPresentInRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new RemoveDependency("flask", null, null)), + requirementsTxt("requests>=2.28.0") + ); + } + + @Test + void removeDependencyFromPipfile() { + rewriteRun( + spec -> spec.recipe(new RemoveDependency("click", null, null)), + pipfile( + """ + [packages] + requests = ">=2.28.0" + click = ">=8.0" + """, + """ + [packages] + requests = ">=2.28.0" + """ + ) + ); + } } diff --git a/rewrite-python/src/test/java/org/openrewrite/python/UpgradeDependencyVersionTest.java b/rewrite-python/src/test/java/org/openrewrite/python/UpgradeDependencyVersionTest.java index 0ec4c7b48c..641fb2ae5a 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/UpgradeDependencyVersionTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/UpgradeDependencyVersionTest.java @@ -21,8 +21,7 @@ import java.nio.file.Path; -import static org.openrewrite.python.Assertions.pyproject; -import static org.openrewrite.python.Assertions.uv; +import static org.openrewrite.python.Assertions.*; class UpgradeDependencyVersionTest implements RewriteTest { @@ -285,4 +284,42 @@ void changeVersionInDependencyGroup() { ) ); } + + @Test + void changeVersionInRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new UpgradeDependencyVersion("requests", ">=2.31.0", null, null)), + requirementsTxt( + "requests>=2.28.0\nclick>=8.0", + "requests>=2.31.0\nclick>=8.0" + ) + ); + } + + @Test + void skipWhenNotPresentInRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new UpgradeDependencyVersion("flask", ">=3.0", null, null)), + requirementsTxt("requests>=2.28.0") + ); + } + + @Test + void changeVersionInPipfile() { + rewriteRun( + spec -> spec.recipe(new UpgradeDependencyVersion("requests", ">=2.31.0", null, null)), + pipfile( + """ + [packages] + requests = ">=2.28.0" + click = ">=8.0" + """, + """ + [packages] + requests = ">=2.31.0" + click = ">=8.0" + """ + ) + ); + } } diff --git a/rewrite-python/src/test/java/org/openrewrite/python/UpgradeTransitiveDependencyVersionTest.java b/rewrite-python/src/test/java/org/openrewrite/python/UpgradeTransitiveDependencyVersionTest.java index b24883eea8..7c13b7c06d 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/UpgradeTransitiveDependencyVersionTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/UpgradeTransitiveDependencyVersionTest.java @@ -21,8 +21,7 @@ import java.nio.file.Path; -import static org.openrewrite.python.Assertions.pyproject; -import static org.openrewrite.python.Assertions.uv; +import static org.openrewrite.python.Assertions.*; class UpgradeTransitiveDependencyVersionTest implements RewriteTest { @@ -494,4 +493,35 @@ void skipFallbackWhenDirectDependency() { ) ); } + + @Test + void addTransitivePinToPipfile() { + rewriteRun( + spec -> spec.recipe(new UpgradeTransitiveDependencyVersion("certifi", ">=2023.7.22")), + pipfile( + """ + [packages] + requests = ">=2.28.0" + """, + """ + [packages] + requests = ">=2.28.0" + certifi = ">=2023.7.22" + """ + ) + ); + } + + @Test + void skipDirectDependencyInPipfile() { + rewriteRun( + spec -> spec.recipe(new UpgradeTransitiveDependencyVersion("requests", ">=2.31.0")), + pipfile( + """ + [packages] + requests = ">=2.28.0" + """ + ) + ); + } } diff --git a/rewrite-python/src/test/java/org/openrewrite/python/trait/PipfileFileTest.java b/rewrite-python/src/test/java/org/openrewrite/python/trait/PipfileFileTest.java new file mode 100644 index 0000000000..95fad0124d --- /dev/null +++ b/rewrite-python/src/test/java/org/openrewrite/python/trait/PipfileFileTest.java @@ -0,0 +1,297 @@ +/* + * Copyright 2026 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.python.trait; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.openrewrite.*; +import org.openrewrite.marker.Markers; +import org.openrewrite.marker.SearchResult; +import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.python.marker.PythonResolutionResult.Dependency; +import org.openrewrite.python.marker.PythonResolutionResult.ResolvedDependency; +import org.openrewrite.test.RewriteTest; +import org.openrewrite.toml.TomlParser; +import org.openrewrite.toml.tree.Toml; + +import java.nio.file.Paths; +import java.util.*; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.openrewrite.Tree.randomId; + +class PipfileFileTest implements RewriteTest { + + private static PythonResolutionResult createMarker(List dependencies) { + return new PythonResolutionResult( + randomId(), null, null, null, null, + "Pipfile", null, null, + Collections.emptyList(), dependencies, + Collections.emptyMap(), Collections.emptyMap(), + Collections.emptyList(), Collections.emptyList(), + Collections.emptyList(), PythonResolutionResult.PackageManager.Pipenv, null + ); + } + + private static Toml.Document parsePipfile(String content, PythonResolutionResult marker) { + TomlParser parser = new TomlParser(); + Parser.Input input = Parser.Input.fromString(Paths.get("Pipfile"), content); + List parsed = parser.parseInputs( + Collections.singletonList(input), null, + new InMemoryExecutionContext(Throwable::printStackTrace) + ).collect(Collectors.toList()); + Toml.Document doc = (Toml.Document) parsed.get(0); + return doc.withMarkers(doc.getMarkers().addIfAbsent(marker)); + } + + private static Cursor rootCursor(Object value) { + return new Cursor(new Cursor(null, Cursor.ROOT_VALUE), value); + } + + private static PipfileFile trait(Toml.Document doc, PythonResolutionResult marker) { + return new PipfileFile(rootCursor(doc), marker); + } + + @Nested + class MatcherTest { + @Test + void matchesPipfile() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + + PythonDependencyFile.Matcher matcher = new PythonDependencyFile.Matcher(); + PythonDependencyFile result = matcher.test(rootCursor(doc)); + + assertThat(result).isNotNull(); + assertThat(result).isInstanceOf(PipfileFile.class); + } + + @Test + void doesNotMatchWithoutMarker() { + TomlParser parser = new TomlParser(); + Parser.Input input = Parser.Input.fromString(Paths.get("Pipfile"), "[packages]\nrequests = \"*\""); + Toml.Document doc = (Toml.Document) parser.parseInputs( + Collections.singletonList(input), null, + new InMemoryExecutionContext(Throwable::printStackTrace) + ).collect(Collectors.toList()).get(0); + + PipfileFile.Matcher matcher = new PipfileFile.Matcher(); + assertThat(matcher.test(rootCursor(doc))).isNull(); + } + + @Test + void doesNotMatchPyprojectToml() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + TomlParser parser = new TomlParser(); + Parser.Input input = Parser.Input.fromString(Paths.get("pyproject.toml"), "[project]\nname = \"test\""); + Toml.Document doc = (Toml.Document) parser.parseInputs( + Collections.singletonList(input), null, + new InMemoryExecutionContext(Throwable::printStackTrace) + ).collect(Collectors.toList()).get(0); + doc = doc.withMarkers(doc.getMarkers().addIfAbsent(marker)); + + PipfileFile.Matcher matcher = new PipfileFile.Matcher(); + assertThat(matcher.test(rootCursor(doc))).isNull(); + } + } + + @Nested + class UpgradeVersionTest { + @Test + void upgradeSimpleVersion() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile upgraded = t.withUpgradedVersions( + Collections.singletonMap("requests", ">=2.31.0"), null, null); + + String printed = ((Toml.Document) upgraded.getTree()).printAll(); + assertThat(printed).contains("requests = \">=2.31.0\""); + } + + @Test + void upgradeInDevPackages() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile( + "[packages]\nflask = \"*\"\n\n[dev-packages]\npytest = \">=7.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile upgraded = t.withUpgradedVersions( + Collections.singletonMap("pytest", ">=8.0"), "dev-packages", null); + + String printed = ((Toml.Document) upgraded.getTree()).printAll(); + assertThat(printed).contains("pytest = \">=8.0\""); + assertThat(printed).contains("flask = \"*\""); + } + + @Test + void noOpWhenNotFound() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile upgraded = t.withUpgradedVersions( + Collections.singletonMap("nonexistent", ">=1.0"), null, null); + + assertThat(upgraded).isSameAs(t); + } + } + + @Nested + class AddDependencyTest { + @Test + void addToPackages() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile added = t.withAddedDependencies( + Collections.singletonMap("flask", ">=2.0"), "packages", null); + + String printed = ((Toml.Document) added.getTree()).printAll(); + assertThat(printed).contains("flask = \">=2.0\""); + assertThat(printed).contains("requests = \">=2.28.0\""); + } + + @Test + void addToDevPackages() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile( + "[packages]\nrequests = \"*\"\n\n[dev-packages]\npytest = \">=7.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile added = t.withAddedDependencies( + Collections.singletonMap("mypy", ">=1.0"), "dev-packages", null); + + String printed = ((Toml.Document) added.getTree()).printAll(); + assertThat(printed).contains("mypy = \">=1.0\""); + } + + @Test + void noOpWhenAlreadyPresent() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile added = t.withAddedDependencies( + Collections.singletonMap("requests", ">=2.31.0"), "packages", null); + + assertThat(added).isSameAs(t); + } + } + + @Nested + class RemoveDependencyTest { + @Test + void removeFromPackages() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"\nflask = \"*\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile removed = t.withRemovedDependencies( + Collections.singleton("flask"), "packages", null); + + String printed = ((Toml.Document) removed.getTree()).printAll(); + assertThat(printed).contains("requests = \">=2.28.0\""); + assertThat(printed).doesNotContain("flask"); + } + + @Test + void noOpWhenNotFound() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile removed = t.withRemovedDependencies( + Collections.singleton("nonexistent"), "packages", null); + + assertThat(removed).isSameAs(t); + } + } + + @Nested + class ChangeDependencyTest { + @Test + void renamePackage() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile changed = t.withChangedDependency("requests", "httpx", null); + + String printed = ((Toml.Document) changed.getTree()).printAll(); + assertThat(printed).contains("httpx = \">=2.28.0\""); + assertThat(printed).doesNotContain("requests"); + } + + @Test + void renameWithNewVersion() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile changed = t.withChangedDependency("requests", "httpx", ">=0.24.0"); + + String printed = ((Toml.Document) changed.getTree()).printAll(); + assertThat(printed).contains("httpx = \">=0.24.0\""); + } + } + + @Nested + class SearchMarkersTest { + @Test + void markVulnerableDependency() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile( + "[packages]\nrequests = \">=2.28.0\"\nflask = \"*\"", marker); + PipfileFile t = trait(doc, marker); + + ExecutionContext ctx = new InMemoryExecutionContext(Throwable::printStackTrace); + PipfileFile marked = t.withDependencySearchMarkers( + Collections.singletonMap("requests", "CVE-2023-1234"), ctx); + + Toml.Document result = (Toml.Document) marked.getTree(); + boolean[] found = {false}; + new org.openrewrite.toml.TomlVisitor() { + @Override + public Toml visitKeyValue(Toml.KeyValue keyValue, Integer p) { + if (keyValue.getKey() instanceof Toml.Identifier && + "requests".equals(((Toml.Identifier) keyValue.getKey()).getName()) && + keyValue.getMarkers().findFirst(SearchResult.class).isPresent()) { + found[0] = true; + } + return keyValue; + } + }.visit(result, 0); + assertThat(found[0]).as("requests should have SearchResult marker").isTrue(); + } + + @Test + void noOpWhenNoMatch() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + ExecutionContext ctx = new InMemoryExecutionContext(Throwable::printStackTrace); + PipfileFile marked = t.withDependencySearchMarkers( + Collections.singletonMap("nonexistent", "CVE-2023-9999"), ctx); + + assertThat(marked).isSameAs(t); + } + } +} diff --git a/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java b/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java new file mode 100644 index 0000000000..18aa010735 --- /dev/null +++ b/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java @@ -0,0 +1,751 @@ +/* + * Copyright 2026 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.python.trait; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.openrewrite.*; +import org.openrewrite.marker.Markers; +import org.openrewrite.marker.SearchResult; +import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.python.marker.PythonResolutionResult.Dependency; +import org.openrewrite.python.marker.PythonResolutionResult.ResolvedDependency; +import org.openrewrite.test.RewriteTest; +import org.openrewrite.text.PlainText; +import org.openrewrite.toml.TomlParser; +import org.openrewrite.toml.tree.Toml; + +import java.nio.file.Paths; +import java.util.*; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.openrewrite.Tree.randomId; +import static org.openrewrite.python.Assertions.pyproject; +import static org.openrewrite.python.Assertions.requirementsTxt; + +class PythonDependencyFileTest implements RewriteTest { + + // region Helper methods + + private static PythonResolutionResult createMarker(List dependencies, + List resolved) { + return new PythonResolutionResult( + randomId(), "test-project", "1.0.0", null, null, + ".", null, null, + Collections.emptyList(), dependencies, + Collections.emptyMap(), Collections.emptyMap(), + Collections.emptyList(), Collections.emptyList(), + resolved, null, null + ); + } + + private static Toml.Document parseToml(String content, PythonResolutionResult marker) { + TomlParser parser = new TomlParser(); + Parser.Input input = Parser.Input.fromString(Paths.get("pyproject.toml"), content); + List parsed = parser.parseInputs( + Collections.singletonList(input), null, + new InMemoryExecutionContext(Throwable::printStackTrace) + ).collect(Collectors.toList()); + Toml.Document doc = (Toml.Document) parsed.get(0); + return doc.withMarkers(doc.getMarkers().addIfAbsent(marker)); + } + + private static PlainText createRequirementsTxt(String content, PythonResolutionResult marker) { + return new PlainText( + randomId(), Paths.get("requirements.txt"), + Markers.EMPTY.addIfAbsent(marker), + "UTF-8", false, null, null, content, null + ); + } + + private static Cursor rootCursor(Object value) { + return new Cursor(new Cursor(null, Cursor.ROOT_VALUE), value); + } + + private static PyProjectFile pyProjectTrait(Toml.Document doc, PythonResolutionResult marker) { + return new PyProjectFile(rootCursor(doc), marker); + } + + private static RequirementsFile requirementsTrait(PlainText pt, PythonResolutionResult marker) { + return new RequirementsFile(rootCursor(pt), marker); + } + + /** + * A recipe that applies {@link PythonDependencyFile#withDependencySearchMarkers} via the trait matcher. + */ + private static Recipe searchMarkersRecipe(Map packageMessages) { + return RewriteTest.toRecipe(() -> new TreeVisitor() { + final PythonDependencyFile.Matcher matcher = new PythonDependencyFile.Matcher(); + + @Override + public Tree preVisit(Tree tree, ExecutionContext ctx) { + PythonDependencyFile trait = matcher.test(getCursor()); + if (trait != null) { + return trait.withDependencySearchMarkers(packageMessages, ctx).getTree(); + } + return tree; + } + }); + } + + // endregion + + @Nested + class RewritePep508SpecTest { + @Test + void simpleUpgrade() { + String result = PythonDependencyFile.rewritePep508Spec("requests>=2.28.0", "requests", "2.31.0"); + assertThat(result).isEqualTo("requests>=2.31.0"); + } + + @Test + void preservesExtras() { + String result = PythonDependencyFile.rewritePep508Spec("requests[security]>=2.28.0", "requests", "2.31.0"); + assertThat(result).isEqualTo("requests[security]>=2.31.0"); + } + + @Test + void preservesEnvironmentMarker() { + String result = PythonDependencyFile.rewritePep508Spec( + "pywin32>=300; sys_platform=='win32'", "pywin32", "306"); + assertThat(result).isEqualTo("pywin32>=306; sys_platform=='win32'"); + } + + @Test + void preservesExtrasAndMarker() { + String result = PythonDependencyFile.rewritePep508Spec( + "requests[security]>=2.28.0; python_version>='3.8'", "requests", "2.31.0"); + assertThat(result).isEqualTo("requests[security]>=2.31.0; python_version>='3.8'"); + } + + @Test + void nameOnly() { + String result = PythonDependencyFile.rewritePep508Spec("requests", "requests", "2.31.0"); + assertThat(result).isEqualTo("requests>=2.31.0"); + } + } + + @Nested + class UpdateResolvedVersionsTest { + @Test + void updatesMatchingVersions() { + ResolvedDependency requests = new ResolvedDependency("requests", "2.28.0", null, null); + ResolvedDependency flask = new ResolvedDependency("flask", "2.0.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), Arrays.asList(requests, flask)); + + Map updates = new HashMap<>(); + updates.put("requests", "2.31.0"); + + PythonResolutionResult updated = PythonDependencyFile.updateResolvedVersions(marker, updates); + + assertThat(updated.getResolvedDependencies()).hasSize(2); + assertThat(updated.getResolvedDependencies().get(0).getVersion()).isEqualTo("2.31.0"); + assertThat(updated.getResolvedDependencies().get(1).getVersion()).isEqualTo("2.0.0"); + } + + @Test + void returnsOriginalWhenNoChanges() { + ResolvedDependency requests = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), Collections.singletonList(requests)); + + Map updates = new HashMap<>(); + updates.put("nonexistent", "1.0.0"); + + PythonResolutionResult updated = PythonDependencyFile.updateResolvedVersions(marker, updates); + + assertThat(updated).isSameAs(marker); + } + + @Test + void returnsOriginalWhenVersionUnchanged() { + ResolvedDependency requests = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), Collections.singletonList(requests)); + + Map updates = new HashMap<>(); + updates.put("requests", "2.28.0"); + + PythonResolutionResult updated = PythonDependencyFile.updateResolvedVersions(marker, updates); + + assertThat(updated).isSameAs(marker); + } + } + + @Nested + class MatcherTest { + @Test + void matchesPyProjectToml() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.31.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), Collections.singletonList(resolved)); + Toml.Document doc = parseToml("[project]\nname = \"test\"\ndependencies = [\"requests>=2.28.0\"]", marker); + + PythonDependencyFile.Matcher matcher = new PythonDependencyFile.Matcher(); + PythonDependencyFile result = matcher.test(rootCursor(doc)); + + assertThat(result).isNotNull(); + assertThat(result).isInstanceOf(PyProjectFile.class); + } + + @Test + void matchesRequirementsTxt() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.31.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), Collections.singletonList(resolved)); + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + + PythonDependencyFile.Matcher matcher = new PythonDependencyFile.Matcher(); + PythonDependencyFile result = matcher.test(rootCursor(pt)); + + assertThat(result).isNotNull(); + assertThat(result).isInstanceOf(RequirementsFile.class); + } + + @Test + void doesNotMatchWithoutMarker() { + TomlParser parser = new TomlParser(); + Parser.Input input = Parser.Input.fromString(Paths.get("pyproject.toml"), + "[project]\nname = \"test\""); + Toml.Document doc = (Toml.Document) parser.parseInputs( + Collections.singletonList(input), null, + new InMemoryExecutionContext(Throwable::printStackTrace) + ).collect(Collectors.toList()).get(0); + + PythonDependencyFile.Matcher matcher = new PythonDependencyFile.Matcher(); + assertThat(matcher.test(rootCursor(doc))).isNull(); + } + + @Test + void doesNotMatchNonPythonFile() { + PlainText pt = new PlainText( + randomId(), Paths.get("readme.txt"), + Markers.EMPTY, "UTF-8", false, null, null, "hello", null + ); + + PythonDependencyFile.Matcher matcher = new PythonDependencyFile.Matcher(); + assertThat(matcher.test(rootCursor(pt))).isNull(); + } + } + + @Nested + class PyProjectFileTest { + + @Test + void upgradesDependencyVersion() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + Dependency dep = new Dependency("requests", ">=2.28.0", null, null, resolved); + PythonResolutionResult marker = createMarker(Collections.singletonList(dep), + Collections.singletonList(resolved)); + + String toml = "[project]\nname = \"test\"\ndependencies = [\n \"requests>=2.28.0\",\n]"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + Toml.Document result = (Toml.Document) upgraded.getTree(); + String printed = result.printAll(); + assertThat(printed).contains("\"requests>=2.31.0\""); + assertThat(printed).doesNotContain("\"requests>=2.28.0\""); + } + + @Test + void upgradePreservesExtras() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + String toml = "[project]\nname = \"test\"\ndependencies = [\n \"requests[security]>=2.28.0\",\n]"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + String printed = ((Toml.Document) upgraded.getTree()).printAll(); + assertThat(printed).contains("\"requests[security]>=2.31.0\""); + } + + @Test + void upgradeNoOpWhenPackageNotFound() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + String toml = "[project]\nname = \"test\"\ndependencies = [\n \"requests>=2.28.0\",\n]"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map upgrades = Collections.singletonMap("nonexistent", "1.0.0"); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + assertThat(upgraded).isSameAs(trait); + } + + @Test + void upgradeUpdatesResolvedVersionsInMarker() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + String toml = "[project]\nname = \"test\"\ndependencies = [\n \"requests>=2.28.0\",\n]"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + assertThat(upgraded.getMarker().getResolvedDependencies().get(0).getVersion()).isEqualTo("2.31.0"); + } + + @Test + void searchMarkersOnVulnerableDependency() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + String toml = "[project]\nname = \"test\"\ndependencies = [\n \"requests>=2.28.0\",\n \"flask>=2.0\",\n]"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map vulnerabilities = Collections.singletonMap("requests", "CVE-2023-1234"); + ExecutionContext ctx = new InMemoryExecutionContext(Throwable::printStackTrace); + PyProjectFile marked = trait.withDependencySearchMarkers(vulnerabilities, ctx); + + Toml.Document result = (Toml.Document) marked.getTree(); + new org.openrewrite.toml.TomlVisitor() { + @Override + public Toml visitLiteral(Toml.Literal literal, Integer p) { + if (literal.getValue().toString().contains("requests")) { + assertThat(literal.getMarkers().findFirst(SearchResult.class)).isPresent(); + } + return literal; + } + }.visit(result, 0); + assertThat(result).isNotSameAs(doc); + } + + @Test + void searchMarkersNoOpWhenNoMatch() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + String toml = "[project]\nname = \"test\"\ndependencies = [\n \"requests>=2.28.0\",\n]"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map vulnerabilities = Collections.singletonMap("nonexistent", "CVE-2023-9999"); + ExecutionContext ctx = new InMemoryExecutionContext(Throwable::printStackTrace); + PyProjectFile marked = trait.withDependencySearchMarkers(vulnerabilities, ctx); + + assertThat(marked).isSameAs(trait); + } + + @Test + void upgradeMultipleDependencies() { + ResolvedDependency requests = new ResolvedDependency("requests", "2.28.0", null, null); + ResolvedDependency flask = new ResolvedDependency("flask", "2.0.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Arrays.asList(requests, flask)); + + String toml = "[project]\nname = \"test\"\ndependencies = [\n \"requests>=2.28.0\",\n \"flask>=2.0.0\",\n]"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map upgrades = new HashMap<>(); + upgrades.put("requests", "2.31.0"); + upgrades.put("flask", "3.0.0"); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + String printed = ((Toml.Document) upgraded.getTree()).printAll(); + assertThat(printed).contains("\"requests>=2.31.0\""); + assertThat(printed).contains("\"flask>=3.0.0\""); + } + + @Test + void doesNotUpgradeDependenciesOutsideProjectSection() { + ResolvedDependency resolved = new ResolvedDependency("setuptools", "68.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + String toml = "[build-system]\nrequires = [\"setuptools>=67.0\"]\n\n[project]\nname = \"test\"\ndependencies = []"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map upgrades = Collections.singletonMap("setuptools", "69.0"); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + // build-system is not inside [project], so it should not be upgraded + assertThat(upgraded).isSameAs(trait); + } + } + + @Nested + class RequirementsFileTest { + + @Test + void upgradesDependencyVersion() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0\nflask>=2.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + PlainText result = (PlainText) upgraded.getTree(); + assertThat(result.getText()).contains("requests>=2.31.0"); + assertThat(result.getText()).contains("flask>=2.0"); + } + + @Test + void upgradePreservesExtras() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests[security]>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + assertThat(((PlainText) upgraded.getTree()).getText()).isEqualTo("requests[security]>=2.31.0"); + } + + @Test + void upgradePreservesEnvironmentMarkers() { + ResolvedDependency resolved = new ResolvedDependency("pywin32", "300", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("pywin32>=300; sys_platform=='win32'", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("pywin32", "306"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + assertThat(((PlainText) upgraded.getTree()).getText()) + .isEqualTo("pywin32>=306; sys_platform=='win32'"); + } + + @Test + void upgradeSkipsComments() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("# this is a comment\nrequests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + String text = ((PlainText) upgraded.getTree()).getText(); + assertThat(text).startsWith("# this is a comment\n"); + assertThat(text).contains("requests>=2.31.0"); + } + + @Test + void upgradeSkipsFlags() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("-r base.txt\nrequests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + String text = ((PlainText) upgraded.getTree()).getText(); + assertThat(text).startsWith("-r base.txt\n"); + assertThat(text).contains("requests>=2.31.0"); + } + + @Test + void upgradeNoOpWhenPackageNotFound() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("nonexistent", "1.0.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + assertThat(upgraded).isSameAs(trait); + } + + @Test + void upgradeUpdatesResolvedVersionsInMarker() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + assertThat(upgraded.getMarker().getResolvedDependencies().get(0).getVersion()).isEqualTo("2.31.0"); + } + + @Test + void upgradePreservesLeadingWhitespace() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt(" requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + assertThat(((PlainText) upgraded.getTree()).getText()).isEqualTo(" requests>=2.31.0"); + } + + @Test + void addsDependencyToEnd() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map additions = Collections.singletonMap("flask", "3.0.0"); + RequirementsFile added = trait.withAddedDependencies(additions, null, null); + + String text = ((PlainText) added.getTree()).getText(); + assertThat(text).isEqualTo("requests>=2.28.0\nflask>=3.0.0"); + } + + @Test + void addDependencyNoOpWhenAlreadyPresent() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map additions = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile added = trait.withAddedDependencies(additions, null, null); + + assertThat(added).isSameAs(trait); + } + + @Test + void searchMarkersOnVulnerableDependency() { + rewriteRun( + spec -> spec.recipe(searchMarkersRecipe( + Collections.singletonMap("requests", "CVE-2023-1234"))), + requirementsTxt( + "requests>=2.28.0\nflask>=2.0", + "~~>requests>=2.28.0\nflask>=2.0" + ) + ); + } + + @Test + void searchMarkersNoOpWhenNoMatch() { + rewriteRun( + spec -> spec.recipe(searchMarkersRecipe( + Collections.singletonMap("nonexistent", "CVE-2023-9999"))), + requirementsTxt("requests>=2.28.0") + ); + } + + @Test + void upgradeMultipleDependencies() { + ResolvedDependency requests = new ResolvedDependency("requests", "2.28.0", null, null); + ResolvedDependency flask = new ResolvedDependency("flask", "2.0.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Arrays.asList(requests, flask)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0\nflask>=2.0.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = new HashMap<>(); + upgrades.put("requests", "2.31.0"); + upgrades.put("flask", "3.0.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + String text = ((PlainText) upgraded.getTree()).getText(); + assertThat(text).contains("requests>=2.31.0"); + assertThat(text).contains("flask>=3.0.0"); + } + + @Test + void upgradeHandlesEmptyLines() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0\n\nflask>=2.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + String text = ((PlainText) upgraded.getTree()).getText(); + assertThat(text).isEqualTo("requests>=2.31.0\n\nflask>=2.0"); + } + } + + @Nested + class PyProjectSearchMarkersTest { + + @Test + void searchMarkersViaMatcher() { + rewriteRun( + spec -> spec.recipe(searchMarkersRecipe( + Collections.singletonMap("requests", "CVE-2023-1234"))), + pyproject( + """ + [project] + name = "test" + dependencies = [ + "requests>=2.28.0", + "flask>=2.0", + ] + """, + """ + [project] + name = "test" + dependencies = [ + ~~(CVE-2023-1234)~~>"requests>=2.28.0", + "flask>=2.0", + ] + """ + ) + ); + } + + @Test + void searchMarkersNoOpViaMatcher() { + rewriteRun( + spec -> spec.recipe(searchMarkersRecipe( + Collections.singletonMap("nonexistent", "CVE-2023-9999"))), + pyproject( + """ + [project] + name = "test" + dependencies = [ + "requests>=2.28.0", + ] + """ + ) + ); + } + } + + @Nested + class RequirementsScopeFilterTest { + + @Test + void nullScopeMatchesAllFiles() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + assertThat(((PlainText) upgraded.getTree()).getText()).contains("requests>=2.31.0"); + } + + @Test + void emptyScopeMatchesRootRequirementsTxt() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + // requirements.txt (no suffix) should match scope="" + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, "", null); + + assertThat(((PlainText) upgraded.getTree()).getText()).contains("requests>=2.31.0"); + } + + @Test + void emptyScopeDoesNotMatchScopedFile() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = new PlainText( + randomId(), Paths.get("requirements-dev.txt"), + Markers.EMPTY.addIfAbsent(marker), + "UTF-8", false, null, null, "requests>=2.28.0", null + ); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, "", null); + + // scope="" should NOT match requirements-dev.txt + assertThat(upgraded).isSameAs(trait); + } + + @Test + void devScopeMatchesDevFile() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = new PlainText( + randomId(), Paths.get("requirements-dev.txt"), + Markers.EMPTY.addIfAbsent(marker), + "UTF-8", false, null, null, "requests>=2.28.0", null + ); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, "dev", null); + + assertThat(((PlainText) upgraded.getTree()).getText()).contains("requests>=2.31.0"); + } + + @Test + void devScopeDoesNotMatchRootFile() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, "dev", null); + + // scope="dev" should NOT match requirements.txt + assertThat(upgraded).isSameAs(trait); + } + } +}