From d9a0e2242db80a873965fbea764c350e95f5b62f Mon Sep 17 00:00:00 2001 From: Jente Sondervorst Date: Fri, 3 Apr 2026 12:24:14 +0200 Subject: [PATCH 01/10] initial version of the trait --- .../python/trait/PyProjectFile.java | 233 ++++++ .../python/trait/PythonDependencyFile.java | 102 +++ .../python/trait/RequirementsFile.java | 142 ++++ .../python/trait/package-info.java | 21 + .../trait/PythonDependencyFileTest.java | 661 ++++++++++++++++++ 5 files changed, 1159 insertions(+) create mode 100644 rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java create mode 100644 rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java create mode 100644 rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java create mode 100644 rewrite-python/src/main/java/org/openrewrite/python/trait/package-info.java create mode 100644 rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java new file mode 100644 index 0000000000..06032f6712 --- /dev/null +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java @@ -0,0 +1,233 @@ +/* + * Copyright 2026 the original author or authors. + * + * Moderne Proprietary. Only for use by Moderne customers under the terms of a commercial contract. + */ +package org.openrewrite.python.trait; + +import lombok.Value; +import org.jspecify.annotations.Nullable; +import org.openrewrite.Cursor; +import org.openrewrite.ExecutionContext; +import org.openrewrite.Tree; +import org.openrewrite.marker.Markers; +import org.openrewrite.marker.SearchResult; +import org.openrewrite.python.internal.PyProjectHelper; +import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.toml.TomlIsoVisitor; +import org.openrewrite.toml.tree.Space; +import org.openrewrite.toml.tree.Toml; +import org.openrewrite.toml.tree.TomlRightPadded; +import org.openrewrite.toml.tree.TomlType; +import org.openrewrite.trait.SimpleTraitMatcher; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +@Value +public class PyProjectFile implements PythonDependencyFile { + + Cursor cursor; + PythonResolutionResult marker; + + @Override + public PyProjectFile withUpgradedVersions(Map upgrades) { + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor>() { + @Override + public Toml.Literal visitLiteral(Toml.Literal literal, Map u) { + if (!isInsideProjectDependencies(getCursor())) { + return literal; + } + + String spec = literal.getValue().toString(); + String packageName = PyProjectHelper.extractPackageName(spec); + if (packageName == null) { + return literal; + } + + String normalizedName = PythonResolutionResult.normalizeName(packageName); + String fixVersion = u.get(normalizedName); + if (fixVersion == null) { + return literal; + } + + String newSpec = PythonDependencyFile.rewritePep508Spec(spec, packageName, fixVersion); + if (!newSpec.equals(spec)) { + return literal.withSource("\"" + newSpec + "\"").withValue(newSpec); + } + return literal; + } + }.visitNonNull(doc, upgrades); + if (result != doc) { + PythonResolutionResult updatedMarker = PythonDependencyFile.updateResolvedVersions(marker, upgrades); + result = result.withMarkers(result.getMarkers() + .removeByType(PythonResolutionResult.class) + .addIfAbsent(updatedMarker)); + return new PyProjectFile(new Cursor(cursor.getParentOrThrow(), result), updatedMarker); + } + return this; + } + + @Override + public PyProjectFile withAddedDependencies(Map additions) { + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document original = doc; + String scope = transitiveConstraintScope(); + for (Map.Entry entry : additions.entrySet()) { + if (PyProjectHelper.findDependencyInScope(marker, entry.getKey(), scope, null) == null) { + String pep508 = entry.getKey() + ">=" + entry.getValue(); + doc = addDependencyToArray(doc, pep508, scope); + } + } + if (doc != original) { + PythonResolutionResult updatedMarker = PythonDependencyFile.updateResolvedVersions(marker, additions); + doc = doc.withMarkers(doc.getMarkers() + .removeByType(PythonResolutionResult.class) + .addIfAbsent(updatedMarker)); + return new PyProjectFile(new Cursor(cursor.getParentOrThrow(), doc), updatedMarker); + } + return this; + } + + /** + * Determine the TOML scope for transitive dependency constraints based on + * the package manager. + */ + private @Nullable String transitiveConstraintScope() { + PythonResolutionResult.PackageManager pm = marker.getPackageManager(); + if (pm == PythonResolutionResult.PackageManager.Uv) { + return "tool.uv.constraint-dependencies"; + } + // TODO: PDM uses [tool.pdm.overrides] (key-value, not array) — needs separate handling + return null; + } + + private static Toml.Document addDependencyToArray(Toml.Document d, String pep508, @Nullable String scope) { + return (Toml.Document) new TomlIsoVisitor() { + @Override + public Toml.Array visitArray(Toml.Array array, Integer p) { + Toml.Array a = super.visitArray(array, p); + if (!PyProjectHelper.isInsideDependencyArray(getCursor(), scope, null)) { + return a; + } + + Toml.Literal newLiteral = new Toml.Literal( + Tree.randomId(), Space.EMPTY, Markers.EMPTY, + TomlType.Primitive.String, "\"" + pep508 + "\"", pep508); + + List> existingPadded = a.getPadding().getValues(); + List> newPadded = new ArrayList<>(); + + boolean isEmpty = existingPadded.size() == 1 && + existingPadded.get(0).getElement() instanceof Toml.Empty; + if (existingPadded.isEmpty() || isEmpty) { + newPadded.add(new TomlRightPadded<>(newLiteral, Space.EMPTY, Markers.EMPTY)); + } else { + TomlRightPadded lastPadded = existingPadded.get(existingPadded.size() - 1); + boolean hasTrailingComma = lastPadded.getElement() instanceof Toml.Empty; + + if (hasTrailingComma) { + int lastRealIdx = existingPadded.size() - 2; + Toml lastRealElement = existingPadded.get(lastRealIdx).getElement(); + Toml.Literal formatted = newLiteral.withPrefix(lastRealElement.getPrefix()); + for (int i = 0; i <= lastRealIdx; i++) { + newPadded.add(existingPadded.get(i)); + } + newPadded.add(new TomlRightPadded<>(formatted, Space.EMPTY, Markers.EMPTY)); + newPadded.add(lastPadded); + } else { + Toml lastElement = lastPadded.getElement(); + Space newPrefix = lastElement.getPrefix().getWhitespace().contains("\n") ? + lastElement.getPrefix() : + Space.SINGLE_SPACE; + Toml.Literal formatted = newLiteral.withPrefix(newPrefix); + for (int i = 0; i < existingPadded.size() - 1; i++) { + newPadded.add(existingPadded.get(i)); + } + newPadded.add(lastPadded.withAfter(Space.EMPTY)); + newPadded.add(new TomlRightPadded<>(formatted, lastPadded.getAfter(), Markers.EMPTY)); + } + } + + return a.getPadding().withValues(newPadded); + } + }.visitNonNull(d, 0); + } + + @Override + public PyProjectFile withDependencySearchMarkers(Map packageMessages, ExecutionContext ctx) { + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor>() { + @Override + public Toml.Literal visitLiteral(Toml.Literal literal, Map msgs) { + if (!isInsideProjectDependencies(getCursor())) { + return literal; + } + + String spec = literal.getValue().toString(); + String packageName = PyProjectHelper.extractPackageName(spec); + if (packageName == null) { + return literal; + } + + String normalizedName = PythonResolutionResult.normalizeName(packageName); + String message = msgs.get(normalizedName); + if (message != null) { + return SearchResult.found(literal, message); + } + return literal; + } + }.visitNonNull(doc, packageMessages); + if (result != doc) { + return new PyProjectFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + private static boolean isInsideProjectDependencies(Cursor cursor) { + Cursor c = cursor; + boolean inArray = false; + boolean inDependencies = false; + boolean inProject = false; + while (c != null) { + Object value = c.getValue(); + if (value instanceof Toml.Array) { + inArray = true; + } else if (value instanceof Toml.KeyValue && inArray) { + Toml.KeyValue kv = (Toml.KeyValue) value; + if (kv.getKey() instanceof Toml.Identifier && + "dependencies".equals(((Toml.Identifier) kv.getKey()).getName())) { + inDependencies = true; + } + } else if (value instanceof Toml.Table && inDependencies) { + Toml.Table table = (Toml.Table) value; + if (table.getName() != null && "project".equals(table.getName().getName())) { + inProject = true; + break; + } + } + c = c.getParent(); + } + return inProject; + } + + public static class Matcher extends SimpleTraitMatcher { + @Override + protected @Nullable PyProjectFile test(Cursor cursor) { + Object value = cursor.getValue(); + if (value instanceof Toml.Document) { + Toml.Document doc = (Toml.Document) value; + if (doc.getSourcePath().toString().endsWith("pyproject.toml")) { + PythonResolutionResult marker = doc.getMarkers() + .findFirst(PythonResolutionResult.class).orElse(null); + if (marker != null) { + return new PyProjectFile(cursor, marker); + } + } + } + return null; + } + } +} diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java new file mode 100644 index 0000000000..59092d06f1 --- /dev/null +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java @@ -0,0 +1,102 @@ +/* + * Copyright 2026 the original author or authors. + * + * Moderne Proprietary. Only for use by Moderne customers under the terms of a commercial contract. + */ +package org.openrewrite.python.trait; + +import org.jspecify.annotations.Nullable; +import org.openrewrite.Cursor; +import org.openrewrite.ExecutionContext; +import org.openrewrite.SourceFile; +import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.trait.SimpleTraitMatcher; +import org.openrewrite.trait.Trait; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Trait for Python dependency files (pyproject.toml, requirements.txt, etc.). + * Use {@link org.openrewrite.python.internal.PyProjectHelper#extractPackageName(String)} + * for PEP 508 package name extraction. + */ + +public interface PythonDependencyFile extends Trait { + + PythonResolutionResult getMarker(); + + PythonDependencyFile withUpgradedVersions(Map upgrades); + + PythonDependencyFile withAddedDependencies(Map additions); + + /** + * Add search result markers for vulnerable dependencies. + * + * @param packageMessages normalized package name → vulnerability description message + */ + PythonDependencyFile withDependencySearchMarkers(Map packageMessages, ExecutionContext ctx); + + /** + * Rewrite a PEP 508 dependency spec to use a new minimum version. + * Preserves extras and environment markers. + */ + static String rewritePep508Spec(String spec, String packageName, String newVersion) { + int nameEnd = packageName.length(); + StringBuilder sb = new StringBuilder(packageName); + + // Preserve extras like [security] + if (nameEnd < spec.length() && spec.charAt(nameEnd) == '[') { + int extrasEnd = spec.indexOf(']', nameEnd); + if (extrasEnd >= 0) { + extrasEnd++; + sb.append(spec, nameEnd, extrasEnd); + nameEnd = extrasEnd; + } + } + + sb.append(">=").append(newVersion); + + // Preserve environment markers (everything after ';') + int semiIdx = spec.indexOf(';', nameEnd); + if (semiIdx >= 0) { + sb.append(spec.substring(semiIdx)); + } + + return sb.toString(); + } + + /** + * Update the resolved dependency versions in a marker to reflect version changes. + * Returns the same marker if no changes were needed. + */ + static PythonResolutionResult updateResolvedVersions( + PythonResolutionResult marker, Map versionUpdates) { + List resolved = marker.getResolvedDependencies(); + List updated = new ArrayList<>(resolved.size()); + boolean changed = false; + for (PythonResolutionResult.ResolvedDependency dep : resolved) { + String normalizedName = PythonResolutionResult.normalizeName(dep.getName()); + String newVersion = versionUpdates.get(normalizedName); + if (newVersion != null && !newVersion.equals(dep.getVersion())) { + updated.add(dep.withVersion(newVersion)); + changed = true; + } else { + updated.add(dep); + } + } + return changed ? marker.withResolvedDependencies(updated) : marker; + } + + class Matcher extends SimpleTraitMatcher { + private final RequirementsFile.Matcher reqMatcher = new RequirementsFile.Matcher(); + private final PyProjectFile.Matcher tomlMatcher = new PyProjectFile.Matcher(); + + @Override + protected @Nullable PythonDependencyFile test(Cursor cursor) { + PythonDependencyFile r = reqMatcher.test(cursor); + return r != null ? r : tomlMatcher.test(cursor); + } + } +} diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java new file mode 100644 index 0000000000..a80e64a32e --- /dev/null +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java @@ -0,0 +1,142 @@ +/* + * Copyright 2026 the original author or authors. + * + * Moderne Proprietary. Only for use by Moderne customers under the terms of a commercial contract. + */ +package org.openrewrite.python.trait; + +import lombok.Value; +import org.jspecify.annotations.Nullable; +import org.openrewrite.Cursor; +import org.openrewrite.ExecutionContext; +import org.openrewrite.python.RequirementsTxtParser; +import org.openrewrite.python.internal.PyProjectHelper; +import org.openrewrite.text.Find; +import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.text.PlainText; +import org.openrewrite.trait.SimpleTraitMatcher; + +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +@Value +public class RequirementsFile implements PythonDependencyFile { + private static final RequirementsTxtParser PARSER = new RequirementsTxtParser(); + + Cursor cursor; + PythonResolutionResult marker; + + @Override + public RequirementsFile withUpgradedVersions(Map upgrades) { + PlainText pt = (PlainText) getTree(); + String text = pt.getText(); + String[] lines = text.split("\n", -1); + boolean changed = false; + + for (int i = 0; i < lines.length; i++) { + String line = lines[i]; + String trimmed = line.trim(); + if (trimmed.isEmpty() || trimmed.startsWith("#") || trimmed.startsWith("-")) { + continue; + } + + String packageName = PyProjectHelper.extractPackageName(trimmed); + if (packageName == null) { + continue; + } + + String normalizedName = PythonResolutionResult.normalizeName(packageName); + String fixVersion = upgrades.get(normalizedName); + if (fixVersion == null) { + continue; + } + + String newSpec = PythonDependencyFile.rewritePep508Spec(trimmed, packageName, fixVersion); + if (!newSpec.equals(trimmed)) { + // Preserve leading whitespace from the original line + int leadingWs = 0; + while (leadingWs < line.length() && Character.isWhitespace(line.charAt(leadingWs))) { + leadingWs++; + } + lines[i] = line.substring(0, leadingWs) + newSpec; + changed = true; + } + } + + if (changed) { + PythonResolutionResult updatedMarker = PythonDependencyFile.updateResolvedVersions(marker, upgrades); + PlainText newPt = pt.withText(String.join("\n", lines)); + newPt = newPt.withMarkers(newPt.getMarkers() + .removeByType(PythonResolutionResult.class) + .addIfAbsent(updatedMarker)); + return new RequirementsFile(new Cursor(cursor.getParentOrThrow(), newPt), updatedMarker); + } + return this; + } + + @Override + public RequirementsFile withAddedDependencies(Map additions) { + PlainText pt = (PlainText) getTree(); + String text = pt.getText(); + String[] lines = text.split("\n", -1); + + Set existingPackages = new HashSet<>(); + for (String line : lines) { + String pkg = PyProjectHelper.extractPackageName(line.trim()); + if (pkg != null) { + existingPackages.add(PythonResolutionResult.normalizeName(pkg)); + } + } + + StringBuilder sb = new StringBuilder(text); + boolean changed = false; + for (Map.Entry entry : additions.entrySet()) { + if (!existingPackages.contains(entry.getKey())) { + sb.append("\n").append(entry.getKey()).append(">=").append(entry.getValue()); + changed = true; + } + } + + if (changed) { + PythonResolutionResult updatedMarker = PythonDependencyFile.updateResolvedVersions(marker, additions); + PlainText newPt = pt.withText(sb.toString()); + newPt = newPt.withMarkers(newPt.getMarkers() + .removeByType(PythonResolutionResult.class) + .addIfAbsent(updatedMarker)); + return new RequirementsFile(new Cursor(cursor.getParentOrThrow(), newPt), updatedMarker); + } + return this; + } + + @Override + public RequirementsFile withDependencySearchMarkers(Map packageMessages, ExecutionContext ctx) { + PlainText result = (PlainText) getTree(); + for (Map.Entry entry : packageMessages.entrySet()) { + Find find = new Find(entry.getKey(), null, false, null, null, null, null, null); + result = (PlainText) find.getVisitor().visitNonNull(result, ctx); + } + if (result != getTree()) { + return new RequirementsFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + public static class Matcher extends SimpleTraitMatcher { + @Override + protected @Nullable RequirementsFile test(Cursor cursor) { + Object value = cursor.getValue(); + if (value instanceof PlainText) { + PlainText pt = (PlainText) value; + if (PARSER.accept(pt.getSourcePath())) { + PythonResolutionResult marker = pt.getMarkers() + .findFirst(PythonResolutionResult.class).orElse(null); + if (marker != null) { + return new RequirementsFile(cursor, marker); + } + } + } + return null; + } + } +} diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/package-info.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/package-info.java new file mode 100644 index 0000000000..b7bf061317 --- /dev/null +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2024 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +@NullMarked +@NonNullFields +package org.openrewrite.python.trait; + +import org.jspecify.annotations.NullMarked; +import org.openrewrite.internal.lang.NonNullFields; diff --git a/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java b/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java new file mode 100644 index 0000000000..2e23972e24 --- /dev/null +++ b/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java @@ -0,0 +1,661 @@ +/* + * Copyright 2026 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.python.trait; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.openrewrite.*; +import org.openrewrite.marker.Markers; +import org.openrewrite.marker.SearchResult; +import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.python.marker.PythonResolutionResult.Dependency; +import org.openrewrite.python.marker.PythonResolutionResult.ResolvedDependency; +import org.openrewrite.test.RewriteTest; +import org.openrewrite.text.PlainText; +import org.openrewrite.toml.TomlParser; +import org.openrewrite.toml.tree.Toml; + +import java.nio.file.Paths; +import java.util.*; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.openrewrite.Tree.randomId; +import static org.openrewrite.python.Assertions.pyproject; +import static org.openrewrite.python.Assertions.requirementsTxt; + +class PythonDependencyFileTest implements RewriteTest { + + // region Helper methods + + private static PythonResolutionResult createMarker(List dependencies, + List resolved) { + return new PythonResolutionResult( + randomId(), "test-project", "1.0.0", null, null, + ".", null, null, + Collections.emptyList(), dependencies, + Collections.emptyMap(), Collections.emptyMap(), + Collections.emptyList(), Collections.emptyList(), + resolved, null, null + ); + } + + private static Toml.Document parseToml(String content, PythonResolutionResult marker) { + TomlParser parser = new TomlParser(); + Parser.Input input = Parser.Input.fromString(Paths.get("pyproject.toml"), content); + List parsed = parser.parseInputs( + Collections.singletonList(input), null, + new InMemoryExecutionContext(Throwable::printStackTrace) + ).collect(Collectors.toList()); + Toml.Document doc = (Toml.Document) parsed.get(0); + return doc.withMarkers(doc.getMarkers().addIfAbsent(marker)); + } + + private static PlainText createRequirementsTxt(String content, PythonResolutionResult marker) { + return new PlainText( + randomId(), Paths.get("requirements.txt"), + Markers.EMPTY.addIfAbsent(marker), + "UTF-8", false, null, null, content, null + ); + } + + private static Cursor rootCursor(Object value) { + return new Cursor(new Cursor(null, Cursor.ROOT_VALUE), value); + } + + private static PyProjectFile pyProjectTrait(Toml.Document doc, PythonResolutionResult marker) { + return new PyProjectFile(rootCursor(doc), marker); + } + + private static RequirementsFile requirementsTrait(PlainText pt, PythonResolutionResult marker) { + return new RequirementsFile(rootCursor(pt), marker); + } + + /** + * A recipe that applies {@link PythonDependencyFile#withDependencySearchMarkers} via the trait matcher. + */ + private static Recipe searchMarkersRecipe(Map packageMessages) { + return RewriteTest.toRecipe(() -> new TreeVisitor() { + final PythonDependencyFile.Matcher matcher = new PythonDependencyFile.Matcher(); + + @Override + public Tree preVisit(Tree tree, ExecutionContext ctx) { + PythonDependencyFile trait = matcher.test(getCursor()); + if (trait != null) { + return trait.withDependencySearchMarkers(packageMessages, ctx).getTree(); + } + return tree; + } + }); + } + + // endregion + + @Nested + class RewritePep508SpecTest { + @Test + void simpleUpgrade() { + String result = PythonDependencyFile.rewritePep508Spec("requests>=2.28.0", "requests", "2.31.0"); + assertThat(result).isEqualTo("requests>=2.31.0"); + } + + @Test + void preservesExtras() { + String result = PythonDependencyFile.rewritePep508Spec("requests[security]>=2.28.0", "requests", "2.31.0"); + assertThat(result).isEqualTo("requests[security]>=2.31.0"); + } + + @Test + void preservesEnvironmentMarker() { + String result = PythonDependencyFile.rewritePep508Spec( + "pywin32>=300; sys_platform=='win32'", "pywin32", "306"); + assertThat(result).isEqualTo("pywin32>=306; sys_platform=='win32'"); + } + + @Test + void preservesExtrasAndMarker() { + String result = PythonDependencyFile.rewritePep508Spec( + "requests[security]>=2.28.0; python_version>='3.8'", "requests", "2.31.0"); + assertThat(result).isEqualTo("requests[security]>=2.31.0; python_version>='3.8'"); + } + + @Test + void nameOnly() { + String result = PythonDependencyFile.rewritePep508Spec("requests", "requests", "2.31.0"); + assertThat(result).isEqualTo("requests>=2.31.0"); + } + } + + @Nested + class UpdateResolvedVersionsTest { + @Test + void updatesMatchingVersions() { + ResolvedDependency requests = new ResolvedDependency("requests", "2.28.0", null, null); + ResolvedDependency flask = new ResolvedDependency("flask", "2.0.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), Arrays.asList(requests, flask)); + + Map updates = new HashMap<>(); + updates.put("requests", "2.31.0"); + + PythonResolutionResult updated = PythonDependencyFile.updateResolvedVersions(marker, updates); + + assertThat(updated.getResolvedDependencies()).hasSize(2); + assertThat(updated.getResolvedDependencies().get(0).getVersion()).isEqualTo("2.31.0"); + assertThat(updated.getResolvedDependencies().get(1).getVersion()).isEqualTo("2.0.0"); + } + + @Test + void returnsOriginalWhenNoChanges() { + ResolvedDependency requests = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), Collections.singletonList(requests)); + + Map updates = new HashMap<>(); + updates.put("nonexistent", "1.0.0"); + + PythonResolutionResult updated = PythonDependencyFile.updateResolvedVersions(marker, updates); + + assertThat(updated).isSameAs(marker); + } + + @Test + void returnsOriginalWhenVersionUnchanged() { + ResolvedDependency requests = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), Collections.singletonList(requests)); + + Map updates = new HashMap<>(); + updates.put("requests", "2.28.0"); + + PythonResolutionResult updated = PythonDependencyFile.updateResolvedVersions(marker, updates); + + assertThat(updated).isSameAs(marker); + } + } + + @Nested + class MatcherTest { + @Test + void matchesPyProjectToml() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.31.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), Collections.singletonList(resolved)); + Toml.Document doc = parseToml("[project]\nname = \"test\"\ndependencies = [\"requests>=2.28.0\"]", marker); + + PythonDependencyFile.Matcher matcher = new PythonDependencyFile.Matcher(); + PythonDependencyFile result = matcher.test(rootCursor(doc)); + + assertThat(result).isNotNull(); + assertThat(result).isInstanceOf(PyProjectFile.class); + } + + @Test + void matchesRequirementsTxt() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.31.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), Collections.singletonList(resolved)); + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + + PythonDependencyFile.Matcher matcher = new PythonDependencyFile.Matcher(); + PythonDependencyFile result = matcher.test(rootCursor(pt)); + + assertThat(result).isNotNull(); + assertThat(result).isInstanceOf(RequirementsFile.class); + } + + @Test + void doesNotMatchWithoutMarker() { + TomlParser parser = new TomlParser(); + Parser.Input input = Parser.Input.fromString(Paths.get("pyproject.toml"), + "[project]\nname = \"test\""); + Toml.Document doc = (Toml.Document) parser.parseInputs( + Collections.singletonList(input), null, + new InMemoryExecutionContext(Throwable::printStackTrace) + ).collect(Collectors.toList()).get(0); + + PythonDependencyFile.Matcher matcher = new PythonDependencyFile.Matcher(); + assertThat(matcher.test(rootCursor(doc))).isNull(); + } + + @Test + void doesNotMatchNonPythonFile() { + PlainText pt = new PlainText( + randomId(), Paths.get("readme.txt"), + Markers.EMPTY, "UTF-8", false, null, null, "hello", null + ); + + PythonDependencyFile.Matcher matcher = new PythonDependencyFile.Matcher(); + assertThat(matcher.test(rootCursor(pt))).isNull(); + } + } + + @Nested + class PyProjectFileTest { + + @Test + void upgradesDependencyVersion() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + Dependency dep = new Dependency("requests", ">=2.28.0", null, null, resolved); + PythonResolutionResult marker = createMarker(Collections.singletonList(dep), + Collections.singletonList(resolved)); + + String toml = "[project]\nname = \"test\"\ndependencies = [\n \"requests>=2.28.0\",\n]"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades); + + Toml.Document result = (Toml.Document) upgraded.getTree(); + String printed = result.printAll(); + assertThat(printed).contains("\"requests>=2.31.0\""); + assertThat(printed).doesNotContain("\"requests>=2.28.0\""); + } + + @Test + void upgradePreservesExtras() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + String toml = "[project]\nname = \"test\"\ndependencies = [\n \"requests[security]>=2.28.0\",\n]"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades); + + String printed = ((Toml.Document) upgraded.getTree()).printAll(); + assertThat(printed).contains("\"requests[security]>=2.31.0\""); + } + + @Test + void upgradeNoOpWhenPackageNotFound() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + String toml = "[project]\nname = \"test\"\ndependencies = [\n \"requests>=2.28.0\",\n]"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map upgrades = Collections.singletonMap("nonexistent", "1.0.0"); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades); + + assertThat(upgraded).isSameAs(trait); + } + + @Test + void upgradeUpdatesResolvedVersionsInMarker() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + String toml = "[project]\nname = \"test\"\ndependencies = [\n \"requests>=2.28.0\",\n]"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades); + + assertThat(upgraded.getMarker().getResolvedDependencies().get(0).getVersion()).isEqualTo("2.31.0"); + } + + @Test + void searchMarkersOnVulnerableDependency() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + String toml = "[project]\nname = \"test\"\ndependencies = [\n \"requests>=2.28.0\",\n \"flask>=2.0\",\n]"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map vulnerabilities = Collections.singletonMap("requests", "CVE-2023-1234"); + ExecutionContext ctx = new InMemoryExecutionContext(Throwable::printStackTrace); + PyProjectFile marked = trait.withDependencySearchMarkers(vulnerabilities, ctx); + + Toml.Document result = (Toml.Document) marked.getTree(); + new org.openrewrite.toml.TomlVisitor() { + @Override + public Toml visitLiteral(Toml.Literal literal, Integer p) { + if (literal.getValue().toString().contains("requests")) { + assertThat(literal.getMarkers().findFirst(SearchResult.class)).isPresent(); + } + return literal; + } + }.visit(result, 0); + assertThat(result).isNotSameAs(doc); + } + + @Test + void searchMarkersNoOpWhenNoMatch() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + String toml = "[project]\nname = \"test\"\ndependencies = [\n \"requests>=2.28.0\",\n]"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map vulnerabilities = Collections.singletonMap("nonexistent", "CVE-2023-9999"); + ExecutionContext ctx = new InMemoryExecutionContext(Throwable::printStackTrace); + PyProjectFile marked = trait.withDependencySearchMarkers(vulnerabilities, ctx); + + assertThat(marked).isSameAs(trait); + } + + @Test + void upgradeMultipleDependencies() { + ResolvedDependency requests = new ResolvedDependency("requests", "2.28.0", null, null); + ResolvedDependency flask = new ResolvedDependency("flask", "2.0.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Arrays.asList(requests, flask)); + + String toml = "[project]\nname = \"test\"\ndependencies = [\n \"requests>=2.28.0\",\n \"flask>=2.0.0\",\n]"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map upgrades = new HashMap<>(); + upgrades.put("requests", "2.31.0"); + upgrades.put("flask", "3.0.0"); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades); + + String printed = ((Toml.Document) upgraded.getTree()).printAll(); + assertThat(printed).contains("\"requests>=2.31.0\""); + assertThat(printed).contains("\"flask>=3.0.0\""); + } + + @Test + void doesNotUpgradeDependenciesOutsideProjectSection() { + ResolvedDependency resolved = new ResolvedDependency("setuptools", "68.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + String toml = "[build-system]\nrequires = [\"setuptools>=67.0\"]\n\n[project]\nname = \"test\"\ndependencies = []"; + Toml.Document doc = parseToml(toml, marker); + PyProjectFile trait = pyProjectTrait(doc, marker); + + Map upgrades = Collections.singletonMap("setuptools", "69.0"); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades); + + // build-system is not inside [project], so it should not be upgraded + assertThat(upgraded).isSameAs(trait); + } + } + + @Nested + class RequirementsFileTest { + + @Test + void upgradesDependencyVersion() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0\nflask>=2.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + + PlainText result = (PlainText) upgraded.getTree(); + assertThat(result.getText()).contains("requests>=2.31.0"); + assertThat(result.getText()).contains("flask>=2.0"); + } + + @Test + void upgradePreservesExtras() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests[security]>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + + assertThat(((PlainText) upgraded.getTree()).getText()).isEqualTo("requests[security]>=2.31.0"); + } + + @Test + void upgradePreservesEnvironmentMarkers() { + ResolvedDependency resolved = new ResolvedDependency("pywin32", "300", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("pywin32>=300; sys_platform=='win32'", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("pywin32", "306"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + + assertThat(((PlainText) upgraded.getTree()).getText()) + .isEqualTo("pywin32>=306; sys_platform=='win32'"); + } + + @Test + void upgradeSkipsComments() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("# this is a comment\nrequests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + + String text = ((PlainText) upgraded.getTree()).getText(); + assertThat(text).startsWith("# this is a comment\n"); + assertThat(text).contains("requests>=2.31.0"); + } + + @Test + void upgradeSkipsFlags() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("-r base.txt\nrequests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + + String text = ((PlainText) upgraded.getTree()).getText(); + assertThat(text).startsWith("-r base.txt\n"); + assertThat(text).contains("requests>=2.31.0"); + } + + @Test + void upgradeNoOpWhenPackageNotFound() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("nonexistent", "1.0.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + + assertThat(upgraded).isSameAs(trait); + } + + @Test + void upgradeUpdatesResolvedVersionsInMarker() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + + assertThat(upgraded.getMarker().getResolvedDependencies().get(0).getVersion()).isEqualTo("2.31.0"); + } + + @Test + void upgradePreservesLeadingWhitespace() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt(" requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + + assertThat(((PlainText) upgraded.getTree()).getText()).isEqualTo(" requests>=2.31.0"); + } + + @Test + void addsDependencyToEnd() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map additions = Collections.singletonMap("flask", "3.0.0"); + RequirementsFile added = trait.withAddedDependencies(additions); + + String text = ((PlainText) added.getTree()).getText(); + assertThat(text).isEqualTo("requests>=2.28.0\nflask>=3.0.0"); + } + + @Test + void addDependencyNoOpWhenAlreadyPresent() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map additions = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile added = trait.withAddedDependencies(additions); + + assertThat(added).isSameAs(trait); + } + + @Test + void searchMarkersOnVulnerableDependency() { + rewriteRun( + spec -> spec.recipe(searchMarkersRecipe( + Collections.singletonMap("requests", "CVE-2023-1234"))), + requirementsTxt( + "requests>=2.28.0\nflask>=2.0", + "~~>requests>=2.28.0\nflask>=2.0" + ) + ); + } + + @Test + void searchMarkersNoOpWhenNoMatch() { + rewriteRun( + spec -> spec.recipe(searchMarkersRecipe( + Collections.singletonMap("nonexistent", "CVE-2023-9999"))), + requirementsTxt("requests>=2.28.0") + ); + } + + @Test + void upgradeMultipleDependencies() { + ResolvedDependency requests = new ResolvedDependency("requests", "2.28.0", null, null); + ResolvedDependency flask = new ResolvedDependency("flask", "2.0.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Arrays.asList(requests, flask)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0\nflask>=2.0.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = new HashMap<>(); + upgrades.put("requests", "2.31.0"); + upgrades.put("flask", "3.0.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + + String text = ((PlainText) upgraded.getTree()).getText(); + assertThat(text).contains("requests>=2.31.0"); + assertThat(text).contains("flask>=3.0.0"); + } + + @Test + void upgradeHandlesEmptyLines() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0\n\nflask>=2.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + + String text = ((PlainText) upgraded.getTree()).getText(); + assertThat(text).isEqualTo("requests>=2.31.0\n\nflask>=2.0"); + } + } + + @Nested + class PyProjectSearchMarkersTest { + + @Test + void searchMarkersViaMatcher() { + rewriteRun( + spec -> spec.recipe(searchMarkersRecipe( + Collections.singletonMap("requests", "CVE-2023-1234"))), + pyproject( + """ + [project] + name = "test" + dependencies = [ + "requests>=2.28.0", + "flask>=2.0", + ] + """, + """ + [project] + name = "test" + dependencies = [ + ~~(CVE-2023-1234)~~>"requests>=2.28.0", + "flask>=2.0", + ] + """ + ) + ); + } + + @Test + void searchMarkersNoOpViaMatcher() { + rewriteRun( + spec -> spec.recipe(searchMarkersRecipe( + Collections.singletonMap("nonexistent", "CVE-2023-9999"))), + pyproject( + """ + [project] + name = "test" + dependencies = [ + "requests>=2.28.0", + ] + """ + ) + ); + } + } +} From fc1d0ca798ac31eb9cc6975449b0a69b53b53820 Mon Sep 17 00:00:00 2001 From: Jente Sondervorst Date: Fri, 3 Apr 2026 14:50:39 +0200 Subject: [PATCH 02/10] initial version of the trait --- .../org/openrewrite/python/AddDependency.java | 98 +++---------------- .../python/trait/PyProjectFile.java | 27 ++--- .../python/trait/PythonDependencyFile.java | 12 ++- .../python/trait/RequirementsFile.java | 4 +- .../trait/PythonDependencyFileTest.java | 4 +- 5 files changed, 33 insertions(+), 112 deletions(-) diff --git a/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java index 5ab5a380a0..3428af2c5b 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java @@ -19,20 +19,15 @@ import lombok.Value; import org.jspecify.annotations.Nullable; import org.openrewrite.*; -import org.openrewrite.marker.Markers; import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.internal.PythonDependencyExecutionContextView; import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.python.trait.PyProjectFile; import org.openrewrite.toml.TomlIsoVisitor; -import org.openrewrite.toml.tree.Space; import org.openrewrite.toml.tree.Toml; -import org.openrewrite.toml.tree.TomlRightPadded; -import org.openrewrite.toml.tree.TomlType; import java.util.*; -import static org.openrewrite.Tree.randomId; - /** * Add a dependency to the {@code [project].dependencies} array in pyproject.toml. * When uv is available, the uv.lock file is regenerated to reflect the change. @@ -147,7 +142,16 @@ public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) String sourcePath = document.getSourcePath().toString(); if (sourcePath.endsWith("pyproject.toml") && acc.projectsToUpdate.contains(sourcePath)) { - return addDependencyToPyproject(document, ctx, acc); + PyProjectFile trait = new PyProjectFile.Matcher().get(getCursor()).orElse(null); + if (trait != null) { + String ver = version != null ? version : ""; + Map additions = Collections.singletonMap(packageName, ver); + PyProjectFile updated = trait.withAddedDependencies(additions, scope, groupName); + Toml.Document result = (Toml.Document) updated.getTree(); + if (result != document) { + return PyProjectHelper.regenerateLockAndRefreshMarker(result, ctx); + } + } } if (sourcePath.endsWith("uv.lock")) { @@ -162,84 +166,4 @@ public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) }; } - private Toml.Document addDependencyToPyproject(Toml.Document document, ExecutionContext ctx, Accumulator acc) { - String pep508 = version != null ? packageName + PyProjectHelper.normalizeVersionConstraint(version) : packageName; - - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { - @Override - public Toml.Array visitArray(Toml.Array array, ExecutionContext ctx) { - Toml.Array a = super.visitArray(array, ctx); - - if (!PyProjectHelper.isInsideDependencyArray(getCursor(), scope, groupName)) { - return a; - } - - Toml.Literal newLiteral = new Toml.Literal( - randomId(), - Space.EMPTY, - Markers.EMPTY, - TomlType.Primitive.String, - "\"" + pep508 + "\"", - pep508 - ); - - List> existingPadded = a.getPadding().getValues(); - List> newPadded = new ArrayList<>(); - - // An empty TOML array [] is represented as a single Toml.Empty element - boolean isEmpty = existingPadded.size() == 1 && - existingPadded.get(0).getElement() instanceof Toml.Empty; - if (existingPadded.isEmpty() || isEmpty) { - newPadded.add(new TomlRightPadded<>(newLiteral, Space.EMPTY, Markers.EMPTY)); - } else { - // Check if the last element is Toml.Empty (trailing comma marker) - TomlRightPadded lastPadded = existingPadded.get(existingPadded.size() - 1); - boolean hasTrailingComma = lastPadded.getElement() instanceof Toml.Empty; - - if (hasTrailingComma) { - // Insert before the Empty element. The Empty's position - // stores the whitespace before ']'. - // Find the last real element to copy its prefix formatting - int lastRealIdx = existingPadded.size() - 2; - Toml lastRealElement = existingPadded.get(lastRealIdx).getElement(); - Toml.Literal formattedLiteral = newLiteral.withPrefix(lastRealElement.getPrefix()); - - // Copy all existing elements up to (not including) the Empty - for (int i = 0; i <= lastRealIdx; i++) { - newPadded.add(existingPadded.get(i)); - } - // Add new literal with empty after (comma added by printer) - newPadded.add(new TomlRightPadded<>(formattedLiteral, Space.EMPTY, Markers.EMPTY)); - // Keep the Empty element for trailing comma + closing bracket whitespace - newPadded.add(lastPadded); - } else { - // No trailing comma — the last real element's after has the space before ']' - Toml lastElement = lastPadded.getElement(); - // For multi-line arrays, use same prefix; for inline, use single space - Space newPrefix = lastElement.getPrefix().getWhitespace().contains("\n") - ? lastElement.getPrefix() - : Space.SINGLE_SPACE; - Toml.Literal formattedLiteral = newLiteral.withPrefix(newPrefix); - - // Copy all existing elements but set last one's after to empty - for (int i = 0; i < existingPadded.size() - 1; i++) { - newPadded.add(existingPadded.get(i)); - } - newPadded.add(lastPadded.withAfter(Space.EMPTY)); - // New element gets the after from the old last element - newPadded.add(new TomlRightPadded<>(formattedLiteral, lastPadded.getAfter(), Markers.EMPTY)); - } - } - - return a.getPadding().withValues(newPadded); - } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; - } - } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java index 06032f6712..8573da3fa9 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java @@ -71,14 +71,13 @@ public Toml.Literal visitLiteral(Toml.Literal literal, Map u) { } @Override - public PyProjectFile withAddedDependencies(Map additions) { + public PyProjectFile withAddedDependencies(Map additions, @Nullable String scope, @Nullable String groupName) { Toml.Document doc = (Toml.Document) getTree(); Toml.Document original = doc; - String scope = transitiveConstraintScope(); for (Map.Entry entry : additions.entrySet()) { - if (PyProjectHelper.findDependencyInScope(marker, entry.getKey(), scope, null) == null) { - String pep508 = entry.getKey() + ">=" + entry.getValue(); - doc = addDependencyToArray(doc, pep508, scope); + if (PyProjectHelper.findDependencyInScope(marker, entry.getKey(), scope, groupName) == null) { + String pep508 = entry.getKey() + PyProjectHelper.normalizeVersionConstraint(entry.getValue()); + doc = addDependencyToArray(doc, pep508, scope, groupName); } } if (doc != original) { @@ -91,25 +90,13 @@ public PyProjectFile withAddedDependencies(Map additions) { return this; } - /** - * Determine the TOML scope for transitive dependency constraints based on - * the package manager. - */ - private @Nullable String transitiveConstraintScope() { - PythonResolutionResult.PackageManager pm = marker.getPackageManager(); - if (pm == PythonResolutionResult.PackageManager.Uv) { - return "tool.uv.constraint-dependencies"; - } - // TODO: PDM uses [tool.pdm.overrides] (key-value, not array) — needs separate handling - return null; - } - - private static Toml.Document addDependencyToArray(Toml.Document d, String pep508, @Nullable String scope) { + private static Toml.Document addDependencyToArray(Toml.Document d, String pep508, + @Nullable String scope, @Nullable String groupName) { return (Toml.Document) new TomlIsoVisitor() { @Override public Toml.Array visitArray(Toml.Array array, Integer p) { Toml.Array a = super.visitArray(array, p); - if (!PyProjectHelper.isInsideDependencyArray(getCursor(), scope, null)) { + if (!PyProjectHelper.isInsideDependencyArray(getCursor(), scope, groupName)) { return a; } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java index 59092d06f1..647efd7298 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java @@ -29,7 +29,17 @@ public interface PythonDependencyFile extends Trait { PythonDependencyFile withUpgradedVersions(Map upgrades); - PythonDependencyFile withAddedDependencies(Map additions); + /** + * Add dependencies to the specified scope. + * + * @param additions normalized package name → version constraint (e.g. {@code "2.0"} or {@code ">=2.0"}) + * @param scope the TOML scope (e.g. {@code "project.optional-dependencies"}, + * {@code "dependency-groups"}), or {@code null} for the default + * ({@code [project].dependencies}) + * @param groupName required when scope is {@code "project.optional-dependencies"} + * or {@code "dependency-groups"}, otherwise {@code null} + */ + PythonDependencyFile withAddedDependencies(Map additions, @Nullable String scope, @Nullable String groupName); /** * Add search result markers for vulnerable dependencies. diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java index a80e64a32e..e9d06db7de 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java @@ -76,7 +76,7 @@ public RequirementsFile withUpgradedVersions(Map upgrades) { } @Override - public RequirementsFile withAddedDependencies(Map additions) { + public RequirementsFile withAddedDependencies(Map additions, @Nullable String scope, @Nullable String groupName) { PlainText pt = (PlainText) getTree(); String text = pt.getText(); String[] lines = text.split("\n", -1); @@ -93,7 +93,7 @@ public RequirementsFile withAddedDependencies(Map additions) { boolean changed = false; for (Map.Entry entry : additions.entrySet()) { if (!existingPackages.contains(entry.getKey())) { - sb.append("\n").append(entry.getKey()).append(">=").append(entry.getValue()); + sb.append("\n").append(entry.getKey()).append(PyProjectHelper.normalizeVersionConstraint(entry.getValue())); changed = true; } } diff --git a/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java b/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java index 2e23972e24..cb8edf1f73 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java @@ -533,7 +533,7 @@ void addsDependencyToEnd() { RequirementsFile trait = requirementsTrait(pt, marker); Map additions = Collections.singletonMap("flask", "3.0.0"); - RequirementsFile added = trait.withAddedDependencies(additions); + RequirementsFile added = trait.withAddedDependencies(additions, null, null); String text = ((PlainText) added.getTree()).getText(); assertThat(text).isEqualTo("requests>=2.28.0\nflask>=3.0.0"); @@ -549,7 +549,7 @@ void addDependencyNoOpWhenAlreadyPresent() { RequirementsFile trait = requirementsTrait(pt, marker); Map additions = Collections.singletonMap("requests", "2.31.0"); - RequirementsFile added = trait.withAddedDependencies(additions); + RequirementsFile added = trait.withAddedDependencies(additions, null, null); assertThat(added).isSameAs(trait); } From c1d13d553b124846a708312b6461c54e12f08968 Mon Sep 17 00:00:00 2001 From: Jente Sondervorst Date: Fri, 3 Apr 2026 14:57:14 +0200 Subject: [PATCH 03/10] initial version of the trait --- .../python/UpgradeDependencyVersion.java | 82 +++---------------- .../python/trait/PyProjectFile.java | 30 ++----- .../python/trait/PythonDependencyFile.java | 18 +++- .../python/trait/RequirementsFile.java | 2 +- .../trait/PythonDependencyFileTest.java | 32 ++++---- 5 files changed, 49 insertions(+), 115 deletions(-) diff --git a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java index 159500b96f..f6fe68c655 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java @@ -22,9 +22,9 @@ import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.internal.PythonDependencyExecutionContextView; import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.python.trait.PyProjectFile; import org.openrewrite.toml.TomlIsoVisitor; import org.openrewrite.toml.tree.Toml; -import org.openrewrite.toml.tree.TomlType; import java.util.*; @@ -147,7 +147,16 @@ public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) String sourcePath = document.getSourcePath().toString(); if (sourcePath.endsWith("pyproject.toml") && acc.projectsToUpdate.contains(sourcePath)) { - return changeVersionInPyproject(document, ctx, acc); + PyProjectFile trait = new PyProjectFile.Matcher().get(getCursor()).orElse(null); + if (trait != null) { + Map upgrades = Collections.singletonMap( + PythonResolutionResult.normalizeName(packageName), newVersion); + PyProjectFile updated = trait.withUpgradedVersions(upgrades, scope, groupName); + Toml.Document result = (Toml.Document) updated.getTree(); + if (result != document) { + return PyProjectHelper.regenerateLockAndRefreshMarker(result, ctx); + } + } } if (sourcePath.endsWith("uv.lock")) { @@ -162,75 +171,6 @@ public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) }; } - private Toml.Document changeVersionInPyproject(Toml.Document document, ExecutionContext ctx, Accumulator acc) { - String normalizedName = PythonResolutionResult.normalizeName(packageName); - - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { - @Override - public Toml.Literal visitLiteral(Toml.Literal literal, ExecutionContext ctx) { - Toml.Literal l = super.visitLiteral(literal, ctx); - if (l.getType() != TomlType.Primitive.String) { - return l; - } - - Object val = l.getValue(); - if (!(val instanceof String)) { - return l; - } - - // Check if we're inside the target dependency array - if (!isInsideTargetDependencies()) { - return l; - } - - // Check if this literal matches the package we're looking for - String spec = (String) val; - String depName = PyProjectHelper.extractPackageName(spec); - if (depName == null || !PythonResolutionResult.normalizeName(depName).equals(normalizedName)) { - return l; - } - - // Build new PEP 508 string preserving extras and markers - String newSpec = buildNewSpec(spec, depName); - return l.withSource("\"" + newSpec + "\"").withValue(newSpec); - } - - private boolean isInsideTargetDependencies() { - // Walk up the cursor to find the enclosing array, then check scope - Cursor c = getCursor(); - while (c != null) { - if (c.getValue() instanceof Toml.Array) { - return PyProjectHelper.isInsideDependencyArray(c, scope, groupName); - } - c = c.getParent(); - } - return false; - } - - private String buildNewSpec(String oldSpec, String depName) { - // Parse extras and markers from old spec - String extras = extractExtras(oldSpec); - String marker = extractMarker(oldSpec); - - StringBuilder sb = new StringBuilder(depName); - if (extras != null) { - sb.append('[').append(extras).append(']'); - } - sb.append(PyProjectHelper.normalizeVersionConstraint(newVersion)); - if (marker != null) { - sb.append("; ").append(marker); - } - return sb.toString(); - } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; - } - static @Nullable String extractExtras(String pep508Spec) { int start = pep508Spec.indexOf('['); int end = pep508Spec.indexOf(']'); diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java index 8573da3fa9..167ba75b4c 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java @@ -32,12 +32,12 @@ public class PyProjectFile implements PythonDependencyFile { PythonResolutionResult marker; @Override - public PyProjectFile withUpgradedVersions(Map upgrades) { + public PyProjectFile withUpgradedVersions(Map upgrades, @Nullable String scope, @Nullable String groupName) { Toml.Document doc = (Toml.Document) getTree(); Toml.Document result = (Toml.Document) new TomlIsoVisitor>() { @Override public Toml.Literal visitLiteral(Toml.Literal literal, Map u) { - if (!isInsideProjectDependencies(getCursor())) { + if (!isInsideTargetArray(getCursor(), scope, groupName)) { return literal; } @@ -149,7 +149,7 @@ public PyProjectFile withDependencySearchMarkers(Map packageMess Toml.Document result = (Toml.Document) new TomlIsoVisitor>() { @Override public Toml.Literal visitLiteral(Toml.Literal literal, Map msgs) { - if (!isInsideProjectDependencies(getCursor())) { + if (!isInsideTargetArray(getCursor(), null, null)) { return literal; } @@ -173,31 +173,15 @@ public Toml.Literal visitLiteral(Toml.Literal literal, Map msgs) return this; } - private static boolean isInsideProjectDependencies(Cursor cursor) { + private static boolean isInsideTargetArray(Cursor cursor, @Nullable String scope, @Nullable String groupName) { Cursor c = cursor; - boolean inArray = false; - boolean inDependencies = false; - boolean inProject = false; while (c != null) { - Object value = c.getValue(); - if (value instanceof Toml.Array) { - inArray = true; - } else if (value instanceof Toml.KeyValue && inArray) { - Toml.KeyValue kv = (Toml.KeyValue) value; - if (kv.getKey() instanceof Toml.Identifier && - "dependencies".equals(((Toml.Identifier) kv.getKey()).getName())) { - inDependencies = true; - } - } else if (value instanceof Toml.Table && inDependencies) { - Toml.Table table = (Toml.Table) value; - if (table.getName() != null && "project".equals(table.getName().getName())) { - inProject = true; - break; - } + if (c.getValue() instanceof Toml.Array) { + return PyProjectHelper.isInsideDependencyArray(c, scope, groupName); } c = c.getParent(); } - return inProject; + return false; } public static class Matcher extends SimpleTraitMatcher { diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java index 647efd7298..4ff72ca724 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java @@ -9,6 +9,7 @@ import org.openrewrite.Cursor; import org.openrewrite.ExecutionContext; import org.openrewrite.SourceFile; +import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.marker.PythonResolutionResult; import org.openrewrite.trait.SimpleTraitMatcher; import org.openrewrite.trait.Trait; @@ -27,7 +28,14 @@ public interface PythonDependencyFile extends Trait { PythonResolutionResult getMarker(); - PythonDependencyFile withUpgradedVersions(Map upgrades); + /** + * Upgrade version constraints for dependencies in the specified scope. + * + * @param upgrades normalized package name → new version + * @param scope the TOML scope, or {@code null} for the default ({@code [project].dependencies}) + * @param groupName required for {@code "project.optional-dependencies"} or {@code "dependency-groups"} + */ + PythonDependencyFile withUpgradedVersions(Map upgrades, @Nullable String scope, @Nullable String groupName); /** * Add dependencies to the specified scope. @@ -49,8 +57,10 @@ public interface PythonDependencyFile extends Trait { PythonDependencyFile withDependencySearchMarkers(Map packageMessages, ExecutionContext ctx); /** - * Rewrite a PEP 508 dependency spec to use a new minimum version. - * Preserves extras and environment markers. + * Rewrite a PEP 508 dependency spec with a new version constraint. + * Preserves extras and environment markers. The version is normalized + * via {@link PyProjectHelper#normalizeVersionConstraint(String)}, + * so both {@code "2.31.0"} and {@code ">=2.31.0"} are accepted. */ static String rewritePep508Spec(String spec, String packageName, String newVersion) { int nameEnd = packageName.length(); @@ -66,7 +76,7 @@ static String rewritePep508Spec(String spec, String packageName, String newVersi } } - sb.append(">=").append(newVersion); + sb.append(PyProjectHelper.normalizeVersionConstraint(newVersion)); // Preserve environment markers (everything after ';') int semiIdx = spec.indexOf(';', nameEnd); diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java index e9d06db7de..0adfd8164f 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java @@ -28,7 +28,7 @@ public class RequirementsFile implements PythonDependencyFile { PythonResolutionResult marker; @Override - public RequirementsFile withUpgradedVersions(Map upgrades) { + public RequirementsFile withUpgradedVersions(Map upgrades, @Nullable String scope, @Nullable String groupName) { PlainText pt = (PlainText) getTree(); String text = pt.getText(); String[] lines = text.split("\n", -1); diff --git a/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java b/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java index cb8edf1f73..910c37ded7 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java @@ -253,7 +253,7 @@ void upgradesDependencyVersion() { PyProjectFile trait = pyProjectTrait(doc, marker); Map upgrades = Collections.singletonMap("requests", "2.31.0"); - PyProjectFile upgraded = trait.withUpgradedVersions(upgrades); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades, null, null); Toml.Document result = (Toml.Document) upgraded.getTree(); String printed = result.printAll(); @@ -272,7 +272,7 @@ void upgradePreservesExtras() { PyProjectFile trait = pyProjectTrait(doc, marker); Map upgrades = Collections.singletonMap("requests", "2.31.0"); - PyProjectFile upgraded = trait.withUpgradedVersions(upgrades); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades, null, null); String printed = ((Toml.Document) upgraded.getTree()).printAll(); assertThat(printed).contains("\"requests[security]>=2.31.0\""); @@ -289,7 +289,7 @@ void upgradeNoOpWhenPackageNotFound() { PyProjectFile trait = pyProjectTrait(doc, marker); Map upgrades = Collections.singletonMap("nonexistent", "1.0.0"); - PyProjectFile upgraded = trait.withUpgradedVersions(upgrades); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades, null, null); assertThat(upgraded).isSameAs(trait); } @@ -305,7 +305,7 @@ void upgradeUpdatesResolvedVersionsInMarker() { PyProjectFile trait = pyProjectTrait(doc, marker); Map upgrades = Collections.singletonMap("requests", "2.31.0"); - PyProjectFile upgraded = trait.withUpgradedVersions(upgrades); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades, null, null); assertThat(upgraded.getMarker().getResolvedDependencies().get(0).getVersion()).isEqualTo("2.31.0"); } @@ -368,7 +368,7 @@ void upgradeMultipleDependencies() { Map upgrades = new HashMap<>(); upgrades.put("requests", "2.31.0"); upgrades.put("flask", "3.0.0"); - PyProjectFile upgraded = trait.withUpgradedVersions(upgrades); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades, null, null); String printed = ((Toml.Document) upgraded.getTree()).printAll(); assertThat(printed).contains("\"requests>=2.31.0\""); @@ -386,7 +386,7 @@ void doesNotUpgradeDependenciesOutsideProjectSection() { PyProjectFile trait = pyProjectTrait(doc, marker); Map upgrades = Collections.singletonMap("setuptools", "69.0"); - PyProjectFile upgraded = trait.withUpgradedVersions(upgrades); + PyProjectFile upgraded = trait.withUpgradedVersions(upgrades, null, null); // build-system is not inside [project], so it should not be upgraded assertThat(upgraded).isSameAs(trait); @@ -406,7 +406,7 @@ void upgradesDependencyVersion() { RequirementsFile trait = requirementsTrait(pt, marker); Map upgrades = Collections.singletonMap("requests", "2.31.0"); - RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); PlainText result = (PlainText) upgraded.getTree(); assertThat(result.getText()).contains("requests>=2.31.0"); @@ -423,7 +423,7 @@ void upgradePreservesExtras() { RequirementsFile trait = requirementsTrait(pt, marker); Map upgrades = Collections.singletonMap("requests", "2.31.0"); - RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); assertThat(((PlainText) upgraded.getTree()).getText()).isEqualTo("requests[security]>=2.31.0"); } @@ -438,7 +438,7 @@ void upgradePreservesEnvironmentMarkers() { RequirementsFile trait = requirementsTrait(pt, marker); Map upgrades = Collections.singletonMap("pywin32", "306"); - RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); assertThat(((PlainText) upgraded.getTree()).getText()) .isEqualTo("pywin32>=306; sys_platform=='win32'"); @@ -454,7 +454,7 @@ void upgradeSkipsComments() { RequirementsFile trait = requirementsTrait(pt, marker); Map upgrades = Collections.singletonMap("requests", "2.31.0"); - RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); String text = ((PlainText) upgraded.getTree()).getText(); assertThat(text).startsWith("# this is a comment\n"); @@ -471,7 +471,7 @@ void upgradeSkipsFlags() { RequirementsFile trait = requirementsTrait(pt, marker); Map upgrades = Collections.singletonMap("requests", "2.31.0"); - RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); String text = ((PlainText) upgraded.getTree()).getText(); assertThat(text).startsWith("-r base.txt\n"); @@ -488,7 +488,7 @@ void upgradeNoOpWhenPackageNotFound() { RequirementsFile trait = requirementsTrait(pt, marker); Map upgrades = Collections.singletonMap("nonexistent", "1.0.0"); - RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); assertThat(upgraded).isSameAs(trait); } @@ -503,7 +503,7 @@ void upgradeUpdatesResolvedVersionsInMarker() { RequirementsFile trait = requirementsTrait(pt, marker); Map upgrades = Collections.singletonMap("requests", "2.31.0"); - RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); assertThat(upgraded.getMarker().getResolvedDependencies().get(0).getVersion()).isEqualTo("2.31.0"); } @@ -518,7 +518,7 @@ void upgradePreservesLeadingWhitespace() { RequirementsFile trait = requirementsTrait(pt, marker); Map upgrades = Collections.singletonMap("requests", "2.31.0"); - RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); assertThat(((PlainText) upgraded.getTree()).getText()).isEqualTo(" requests>=2.31.0"); } @@ -588,7 +588,7 @@ void upgradeMultipleDependencies() { Map upgrades = new HashMap<>(); upgrades.put("requests", "2.31.0"); upgrades.put("flask", "3.0.0"); - RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); String text = ((PlainText) upgraded.getTree()).getText(); assertThat(text).contains("requests>=2.31.0"); @@ -605,7 +605,7 @@ void upgradeHandlesEmptyLines() { RequirementsFile trait = requirementsTrait(pt, marker); Map upgrades = Collections.singletonMap("requests", "2.31.0"); - RequirementsFile upgraded = trait.withUpgradedVersions(upgrades); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); String text = ((PlainText) upgraded.getTree()).getText(); assertThat(text).isEqualTo("requests>=2.31.0\n\nflask>=2.0"); From e862ed23996a8b1ea4e6aa5b9edd0c32155463ec Mon Sep 17 00:00:00 2001 From: Jente Sondervorst Date: Fri, 3 Apr 2026 17:04:33 +0200 Subject: [PATCH 04/10] initial version of the trait --- .../org/openrewrite/python/AddDependency.java | 77 ++-- .../python/UpgradeDependencyVersion.java | 81 ++-- .../UpgradeTransitiveDependencyVersion.java | 350 +++--------------- .../python/trait/PyProjectFile.java | 113 +++++- .../python/trait/PythonDependencyFile.java | 10 + .../python/trait/RequirementsFile.java | 5 + 6 files changed, 252 insertions(+), 384 deletions(-) diff --git a/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java index 3428af2c5b..ebb9b1f2ca 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java @@ -22,7 +22,7 @@ import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.internal.PythonDependencyExecutionContextView; import org.openrewrite.python.marker.PythonResolutionResult; -import org.openrewrite.python.trait.PyProjectFile; +import org.openrewrite.python.trait.PythonDependencyFile; import org.openrewrite.toml.TomlIsoVisitor; import org.openrewrite.toml.tree.Toml; @@ -100,68 +100,69 @@ public Accumulator getInitialValue(ExecutionContext ctx) { @Override public TreeVisitor getScanner(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("uv.lock")) { - PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( - PyProjectHelper.correspondingPyprojectPath(sourcePath), - document.printAll()); - return document; + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; } - - if (!sourcePath.endsWith("pyproject.toml")) { - return document; + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + if (tree instanceof Toml.Document && sourceFile.getSourcePath().toString().endsWith("uv.lock")) { + PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( + PyProjectHelper.correspondingPyprojectPath(sourceFile.getSourcePath().toString()), + ((Toml.Document) tree).printAll()); + return tree; } - Optional resolution = document.getMarkers() - .findFirst(PythonResolutionResult.class); - if (!resolution.isPresent()) { - return document; + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait == null) { + return tree; } - - PythonResolutionResult marker = resolution.get(); - - // Check if the dependency already exists in the target scope - if (PyProjectHelper.findDependencyInScope(marker, packageName, scope, groupName) != null) { - return document; + if (PyProjectHelper.findDependencyInScope(trait.getMarker(), packageName, scope, groupName) != null) { + return tree; } - - acc.projectsToUpdate.add(sourcePath); - return document; + acc.projectsToUpdate.add(sourceFile.getSourcePath().toString()); + return tree; } }; } @Override public TreeVisitor getVisitor(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; + } + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + String sourcePath = sourceFile.getSourcePath().toString(); - if (sourcePath.endsWith("pyproject.toml") && acc.projectsToUpdate.contains(sourcePath)) { - PyProjectFile trait = new PyProjectFile.Matcher().get(getCursor()).orElse(null); + if (acc.projectsToUpdate.contains(sourcePath)) { + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); if (trait != null) { String ver = version != null ? version : ""; Map additions = Collections.singletonMap(packageName, ver); - PyProjectFile updated = trait.withAddedDependencies(additions, scope, groupName); - Toml.Document result = (Toml.Document) updated.getTree(); - if (result != document) { - return PyProjectHelper.regenerateLockAndRefreshMarker(result, ctx); + PythonDependencyFile updated = trait.withAddedDependencies(additions, scope, groupName); + SourceFile result = (SourceFile) updated.getTree(); + if (result != tree) { + if (result instanceof Toml.Document) { + return PyProjectHelper.regenerateLockAndRefreshMarker((Toml.Document) result, ctx); + } + return result; } } } - if (sourcePath.endsWith("uv.lock")) { - Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock(document, ctx); + if (tree instanceof Toml.Document && sourcePath.endsWith("uv.lock")) { + Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock((Toml.Document) tree, ctx); if (updatedLock != null) { return updatedLock; } } - return document; + return tree; } }; } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java index f6fe68c655..5380f02e74 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java @@ -22,7 +22,7 @@ import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.internal.PythonDependencyExecutionContextView; import org.openrewrite.python.marker.PythonResolutionResult; -import org.openrewrite.python.trait.PyProjectFile; +import org.openrewrite.python.trait.PythonDependencyFile; import org.openrewrite.toml.TomlIsoVisitor; import org.openrewrite.toml.tree.Toml; @@ -98,75 +98,74 @@ public Accumulator getInitialValue(ExecutionContext ctx) { @Override public TreeVisitor getScanner(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("uv.lock")) { - PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( - PyProjectHelper.correspondingPyprojectPath(sourcePath), - document.printAll()); - return document; + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; } - - if (!sourcePath.endsWith("pyproject.toml")) { - return document; + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + if (tree instanceof Toml.Document && sourceFile.getSourcePath().toString().endsWith("uv.lock")) { + PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( + PyProjectHelper.correspondingPyprojectPath(sourceFile.getSourcePath().toString()), + ((Toml.Document) tree).printAll()); + return tree; } - Optional resolution = document.getMarkers() - .findFirst(PythonResolutionResult.class); - if (!resolution.isPresent()) { - return document; + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait == null) { + return tree; } - - PythonResolutionResult marker = resolution.get(); - - // Check if the dependency exists in the target scope and has a different version PythonResolutionResult.Dependency dep = PyProjectHelper.findDependencyInScope( - marker, packageName, scope, groupName); + trait.getMarker(), packageName, scope, groupName); if (dep == null) { - return document; + return tree; } - - // Skip if the version constraint already matches if (PyProjectHelper.normalizeVersionConstraint(newVersion).equals(dep.getVersionConstraint())) { - return document; + return tree; } - - acc.projectsToUpdate.add(sourcePath); - return document; + acc.projectsToUpdate.add(sourceFile.getSourcePath().toString()); + return tree; } }; } @Override public TreeVisitor getVisitor(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; + } + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + String sourcePath = sourceFile.getSourcePath().toString(); - if (sourcePath.endsWith("pyproject.toml") && acc.projectsToUpdate.contains(sourcePath)) { - PyProjectFile trait = new PyProjectFile.Matcher().get(getCursor()).orElse(null); + if (acc.projectsToUpdate.contains(sourcePath)) { + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); if (trait != null) { Map upgrades = Collections.singletonMap( PythonResolutionResult.normalizeName(packageName), newVersion); - PyProjectFile updated = trait.withUpgradedVersions(upgrades, scope, groupName); - Toml.Document result = (Toml.Document) updated.getTree(); - if (result != document) { - return PyProjectHelper.regenerateLockAndRefreshMarker(result, ctx); + PythonDependencyFile updated = trait.withUpgradedVersions(upgrades, scope, groupName); + SourceFile result = (SourceFile) updated.getTree(); + if (result != tree) { + if (result instanceof Toml.Document) { + return PyProjectHelper.regenerateLockAndRefreshMarker((Toml.Document) result, ctx); + } + return result; } } } - if (sourcePath.endsWith("uv.lock")) { - Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock(document, ctx); + if (tree instanceof Toml.Document && sourcePath.endsWith("uv.lock")) { + Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock((Toml.Document) tree, ctx); if (updatedLock != null) { return updatedLock; } } - return document; + return tree; } }; } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java index 7b6f286d97..2a6ad1427c 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java @@ -19,21 +19,14 @@ import lombok.Value; import org.jspecify.annotations.Nullable; import org.openrewrite.*; -import org.openrewrite.marker.Markers; import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.internal.PythonDependencyExecutionContextView; import org.openrewrite.python.marker.PythonResolutionResult; -import org.openrewrite.python.marker.PythonResolutionResult.Dependency; -import org.openrewrite.toml.TomlIsoVisitor; -import org.openrewrite.toml.tree.Space; +import org.openrewrite.python.trait.PythonDependencyFile; import org.openrewrite.toml.tree.Toml; -import org.openrewrite.toml.tree.TomlRightPadded; -import org.openrewrite.toml.tree.TomlType; import java.util.*; -import static org.openrewrite.Tree.randomId; - /** * Pin a transitive dependency version by adding or upgrading a constraint in the * appropriate tool-specific section. The strategy depends on the detected package manager: @@ -74,18 +67,8 @@ public String getDescription() { "PDM uses `[tool.pdm.overrides]`, and other managers add a direct dependency."; } - enum Action { - NONE, - ADD_CONSTRAINT, - UPGRADE_CONSTRAINT, - ADD_PDM_OVERRIDE, - UPGRADE_PDM_OVERRIDE, - ADD_DIRECT_DEPENDENCY - } - static class Accumulator { final Set projectsToUpdate = new HashSet<>(); - final Map actions = new HashMap<>(); } @Override @@ -95,317 +78,80 @@ public Accumulator getInitialValue(ExecutionContext ctx) { @Override public TreeVisitor getScanner(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("uv.lock")) { - PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( - PyProjectHelper.correspondingPyprojectPath(sourcePath), - document.printAll()); - return document; + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; } - - if (!sourcePath.endsWith("pyproject.toml")) { - return document; + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + if (tree instanceof Toml.Document && sourceFile.getSourcePath().toString().endsWith("uv.lock")) { + PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( + PyProjectHelper.correspondingPyprojectPath(sourceFile.getSourcePath().toString()), + ((Toml.Document) tree).printAll()); + return tree; } - Optional resolution = document.getMarkers() - .findFirst(PythonResolutionResult.class); - if (!resolution.isPresent()) { - return document; + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait == null) { + return tree; } - - PythonResolutionResult marker = resolution.get(); + PythonResolutionResult marker = trait.getMarker(); // Skip if this is a direct dependency if (marker.findDependency(packageName) != null) { - return document; - } - - PythonResolutionResult.PackageManager pm = marker.getPackageManager(); - Action action = null; - - if (pm == PythonResolutionResult.PackageManager.Uv) { - // Uv: require resolved deps, use constraint-dependencies - if (marker.getResolvedDependency(packageName) == null) { - return document; - } - Dependency existing = PyProjectHelper.findDependencyInScope( - marker, packageName, "tool.uv.constraint-dependencies", null); - if (existing == null) { - action = Action.ADD_CONSTRAINT; - } else if (!PyProjectHelper.normalizeVersionConstraint(version).equals(existing.getVersionConstraint())) { - action = Action.UPGRADE_CONSTRAINT; - } - } else if (pm == PythonResolutionResult.PackageManager.Pdm) { - // PDM: use tool.pdm.overrides - Dependency existing = PyProjectHelper.findDependencyInScope( - marker, packageName, "tool.pdm.overrides", null); - if (existing == null) { - action = Action.ADD_PDM_OVERRIDE; - } else if (!PyProjectHelper.normalizeVersionConstraint(version).equals(existing.getVersionConstraint())) { - action = Action.UPGRADE_PDM_OVERRIDE; - } - } else { - // Fallback: add as direct dependency - action = Action.ADD_DIRECT_DEPENDENCY; + return tree; } - if (action == null) { - return document; + // For uv: skip if not in the resolved dependency tree + if (marker.getPackageManager() == PythonResolutionResult.PackageManager.Uv && + marker.getResolvedDependency(packageName) == null) { + return tree; } - acc.actions.put(sourcePath, action); - acc.projectsToUpdate.add(sourcePath); - return document; + acc.projectsToUpdate.add(sourceFile.getSourcePath().toString()); + return tree; } }; } @Override public TreeVisitor getVisitor(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("pyproject.toml") && acc.projectsToUpdate.contains(sourcePath)) { - Action action = acc.actions.get(sourcePath); - if (action == Action.ADD_CONSTRAINT) { - return addToArray(document, ctx, acc, "tool.uv.constraint-dependencies"); - } else if (action == Action.UPGRADE_CONSTRAINT) { - return upgradeConstraint(document, ctx, acc); - } else if (action == Action.ADD_PDM_OVERRIDE) { - return addPdmOverride(document, ctx, acc); - } else if (action == Action.UPGRADE_PDM_OVERRIDE) { - return upgradePdmOverride(document, ctx, acc); - } else if (action == Action.ADD_DIRECT_DEPENDENCY) { - return addToArray(document, ctx, acc, null); + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; + } + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + String sourcePath = sourceFile.getSourcePath().toString(); + + if (acc.projectsToUpdate.contains(sourcePath)) { + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait != null) { + String normalizedName = PythonResolutionResult.normalizeName(packageName); + Map pins = Collections.singletonMap(normalizedName, version); + PythonDependencyFile updated = trait.withPinnedTransitiveDependencies(pins); + SourceFile result = (SourceFile) updated.getTree(); + if (result != tree) { + if (result instanceof Toml.Document) { + return PyProjectHelper.regenerateLockAndRefreshMarker((Toml.Document) result, ctx); + } + return result; + } } } - if (sourcePath.endsWith("uv.lock")) { - Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock(document, ctx); + if (tree instanceof Toml.Document && sourcePath.endsWith("uv.lock")) { + Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock((Toml.Document) tree, ctx); if (updatedLock != null) { return updatedLock; } } - return document; + return tree; } }; } - - private Toml.Document addToArray(Toml.Document document, ExecutionContext ctx, Accumulator acc, @Nullable String scope) { - String normalizedVersion = PyProjectHelper.normalizeVersionConstraint(version); - String pep508 = packageName + normalizedVersion; - - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { - @Override - public Toml.Array visitArray(Toml.Array array, ExecutionContext ctx) { - Toml.Array a = super.visitArray(array, ctx); - - if (!PyProjectHelper.isInsideDependencyArray(getCursor(), scope, null)) { - return a; - } - - Toml.Literal newLiteral = new Toml.Literal( - randomId(), - Space.EMPTY, - Markers.EMPTY, - TomlType.Primitive.String, - "\"" + pep508 + "\"", - pep508 - ); - - List> existingPadded = a.getPadding().getValues(); - List> newPadded = new ArrayList<>(); - - boolean isEmpty = existingPadded.size() == 1 && - existingPadded.get(0).getElement() instanceof Toml.Empty; - if (existingPadded.isEmpty() || isEmpty) { - newPadded.add(new TomlRightPadded<>(newLiteral, Space.EMPTY, Markers.EMPTY)); - } else { - TomlRightPadded lastPadded = existingPadded.get(existingPadded.size() - 1); - boolean hasTrailingComma = lastPadded.getElement() instanceof Toml.Empty; - - if (hasTrailingComma) { - int lastRealIdx = existingPadded.size() - 2; - Toml lastRealElement = existingPadded.get(lastRealIdx).getElement(); - Toml.Literal formattedLiteral = newLiteral.withPrefix(lastRealElement.getPrefix()); - - for (int i = 0; i <= lastRealIdx; i++) { - newPadded.add(existingPadded.get(i)); - } - newPadded.add(new TomlRightPadded<>(formattedLiteral, Space.EMPTY, Markers.EMPTY)); - newPadded.add(lastPadded); - } else { - Toml lastElement = lastPadded.getElement(); - Space newPrefix = lastElement.getPrefix().getWhitespace().contains("\n") - ? lastElement.getPrefix() - : Space.SINGLE_SPACE; - Toml.Literal formattedLiteral = newLiteral.withPrefix(newPrefix); - - for (int i = 0; i < existingPadded.size() - 1; i++) { - newPadded.add(existingPadded.get(i)); - } - newPadded.add(lastPadded.withAfter(Space.EMPTY)); - newPadded.add(new TomlRightPadded<>(formattedLiteral, lastPadded.getAfter(), Markers.EMPTY)); - } - } - - return a.getPadding().withValues(newPadded); - } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; - } - - private Toml.Document upgradeConstraint(Toml.Document document, ExecutionContext ctx, Accumulator acc) { - String normalizedName = PythonResolutionResult.normalizeName(packageName); - - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { - @Override - public Toml.Literal visitLiteral(Toml.Literal literal, ExecutionContext ctx) { - Toml.Literal l = super.visitLiteral(literal, ctx); - if (l.getType() != TomlType.Primitive.String) { - return l; - } - - Object val = l.getValue(); - if (!(val instanceof String)) { - return l; - } - - // Check if we're inside [tool.uv].constraint-dependencies - if (!isInsideConstraintDependencies()) { - return l; - } - - String spec = (String) val; - String depName = PyProjectHelper.extractPackageName(spec); - if (depName == null || !PythonResolutionResult.normalizeName(depName).equals(normalizedName)) { - return l; - } - - String extras = UpgradeDependencyVersion.extractExtras(spec); - String marker = UpgradeDependencyVersion.extractMarker(spec); - - StringBuilder sb = new StringBuilder(depName); - if (extras != null) { - sb.append('[').append(extras).append(']'); - } - sb.append(PyProjectHelper.normalizeVersionConstraint(version)); - if (marker != null) { - sb.append("; ").append(marker); - } - - String newSpec = sb.toString(); - return l.withSource("\"" + newSpec + "\"").withValue(newSpec); - } - - private boolean isInsideConstraintDependencies() { - Cursor c = getCursor(); - while (c != null) { - if (c.getValue() instanceof Toml.Array) { - return PyProjectHelper.isInsideDependencyArray(c, "tool.uv.constraint-dependencies", null); - } - c = c.getParent(); - } - return false; - } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; - } - - private Toml.Document addPdmOverride(Toml.Document document, ExecutionContext ctx, Accumulator acc) { - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { - @Override - public Toml.Table visitTable(Toml.Table table, ExecutionContext ctx) { - Toml.Table t = super.visitTable(table, ctx); - if (t.getName() == null || !"tool.pdm.overrides".equals(t.getName().getName())) { - return t; - } - - // Build a new KeyValue: packageName = "version" - String normalizedVersion = PyProjectHelper.normalizeVersionConstraint(version); - Toml.Identifier key = new Toml.Identifier( - randomId(), Space.EMPTY, Markers.EMPTY, packageName, packageName); - Toml.Literal value = new Toml.Literal( - randomId(), Space.SINGLE_SPACE, Markers.EMPTY, - TomlType.Primitive.String, "\"" + normalizedVersion + "\"", normalizedVersion); - Toml.KeyValue newKv = new Toml.KeyValue( - randomId(), Space.EMPTY, Markers.EMPTY, - new TomlRightPadded<>(key, Space.SINGLE_SPACE, Markers.EMPTY), - value); - - // Determine prefix for new entry - List values = t.getValues(); - Space entryPrefix; - if (!values.isEmpty()) { - entryPrefix = values.get(values.size() - 1).getPrefix(); - } else { - entryPrefix = Space.format("\n"); - } - newKv = newKv.withPrefix(entryPrefix); - - List newValues = new ArrayList<>(values); - newValues.add(newKv); - return t.withValues(newValues); - } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; - } - - private Toml.Document upgradePdmOverride(Toml.Document document, ExecutionContext ctx, Accumulator acc) { - String normalizedName = PythonResolutionResult.normalizeName(packageName); - - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { - @Override - public Toml.KeyValue visitKeyValue(Toml.KeyValue keyValue, ExecutionContext ctx) { - Toml.KeyValue kv = super.visitKeyValue(keyValue, ctx); - - if (!PyProjectHelper.isInsidePdmOverridesTable(getCursor())) { - return kv; - } - - if (!(kv.getKey() instanceof Toml.Identifier)) { - return kv; - } - String keyName = ((Toml.Identifier) kv.getKey()).getName(); - if (!PythonResolutionResult.normalizeName(keyName).equals(normalizedName)) { - return kv; - } - - if (!(kv.getValue() instanceof Toml.Literal)) { - return kv; - } - - Toml.Literal literal = (Toml.Literal) kv.getValue(); - String normalizedVersion = PyProjectHelper.normalizeVersionConstraint(version); - return kv.withValue(literal.withSource("\"" + normalizedVersion + "\"").withValue(normalizedVersion)); - } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; - } } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java index 167ba75b4c..81b1a36147 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java @@ -21,9 +21,7 @@ import org.openrewrite.toml.tree.TomlType; import org.openrewrite.trait.SimpleTraitMatcher; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; +import java.util.*; @Value public class PyProjectFile implements PythonDependencyFile { @@ -90,6 +88,115 @@ public PyProjectFile withAddedDependencies(Map additions, @Nulla return this; } + @Override + public PyProjectFile withPinnedTransitiveDependencies(Map pins) { + PythonResolutionResult.PackageManager pm = marker.getPackageManager(); + if (pm == PythonResolutionResult.PackageManager.Uv) { + return pinViaArrayScope(pins, "tool.uv.constraint-dependencies"); + } else if (pm == PythonResolutionResult.PackageManager.Pdm) { + return pinViaPdmOverrides(pins); + } else { + // Fallback: add as direct dependency + return withAddedDependencies(pins, null, null); + } + } + + private PyProjectFile pinViaArrayScope(Map pins, String scope) { + PyProjectFile current = this; + for (Map.Entry entry : pins.entrySet()) { + PythonResolutionResult.Dependency existing = PyProjectHelper.findDependencyInScope( + current.marker, entry.getKey(), scope, null); + if (existing == null) { + current = current.withAddedDependencies( + Collections.singletonMap(entry.getKey(), entry.getValue()), scope, null); + } else { + current = current.withUpgradedVersions( + Collections.singletonMap(entry.getKey(), entry.getValue()), scope, null); + } + } + return current; + } + + private PyProjectFile pinViaPdmOverrides(Map pins) { + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document updated = doc; + for (Map.Entry entry : pins.entrySet()) { + PythonResolutionResult.Dependency existing = PyProjectHelper.findDependencyInScope( + marker, entry.getKey(), "tool.pdm.overrides", null); + if (existing == null) { + updated = addPdmOverride(updated, entry.getKey(), entry.getValue()); + } else { + updated = upgradePdmOverride(updated, entry.getKey(), entry.getValue()); + } + } + if (updated != doc) { + return new PyProjectFile(new Cursor(cursor.getParentOrThrow(), updated), marker); + } + return this; + } + + private static Toml.Document addPdmOverride(Toml.Document doc, String packageName, String version) { + String normalizedVersion = PyProjectHelper.normalizeVersionConstraint(version); + return (Toml.Document) new TomlIsoVisitor() { + @Override + public Toml.Table visitTable(Toml.Table table, Integer p) { + Toml.Table t = super.visitTable(table, p); + if (t.getName() == null || !"tool.pdm.overrides".equals(t.getName().getName())) { + return t; + } + + Toml.Identifier key = new Toml.Identifier( + Tree.randomId(), Space.EMPTY, Markers.EMPTY, packageName, packageName); + Toml.Literal value = new Toml.Literal( + Tree.randomId(), Space.SINGLE_SPACE, Markers.EMPTY, + TomlType.Primitive.String, "\"" + normalizedVersion + "\"", normalizedVersion); + Toml.KeyValue newKv = new Toml.KeyValue( + Tree.randomId(), Space.EMPTY, Markers.EMPTY, + new TomlRightPadded<>(key, Space.SINGLE_SPACE, Markers.EMPTY), + value); + + List values = t.getValues(); + Space entryPrefix = !values.isEmpty() + ? values.get(values.size() - 1).getPrefix() + : Space.format("\n"); + newKv = newKv.withPrefix(entryPrefix); + + List newValues = new ArrayList<>(values); + newValues.add(newKv); + return t.withValues(newValues); + } + }.visitNonNull(doc, 0); + } + + private static Toml.Document upgradePdmOverride(Toml.Document doc, String packageName, String version) { + String normalizedName = PythonResolutionResult.normalizeName(packageName); + String normalizedVersion = PyProjectHelper.normalizeVersionConstraint(version); + return (Toml.Document) new TomlIsoVisitor() { + @Override + public Toml.KeyValue visitKeyValue(Toml.KeyValue keyValue, Integer p) { + Toml.KeyValue kv = super.visitKeyValue(keyValue, p); + if (!PyProjectHelper.isInsidePdmOverridesTable(getCursor())) { + return kv; + } + if (!(kv.getKey() instanceof Toml.Identifier)) { + return kv; + } + String keyName = ((Toml.Identifier) kv.getKey()).getName(); + if (!PythonResolutionResult.normalizeName(keyName).equals(normalizedName)) { + return kv; + } + if (!(kv.getValue() instanceof Toml.Literal)) { + return kv; + } + Toml.Literal literal = (Toml.Literal) kv.getValue(); + if (normalizedVersion.equals(literal.getValue())) { + return kv; + } + return kv.withValue(literal.withSource("\"" + normalizedVersion + "\"").withValue(normalizedVersion)); + } + }.visitNonNull(doc, 0); + } + private static Toml.Document addDependencyToArray(Toml.Document d, String pep508, @Nullable String scope, @Nullable String groupName) { return (Toml.Document) new TomlIsoVisitor() { diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java index 4ff72ca724..c8bec7287c 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java @@ -49,6 +49,16 @@ public interface PythonDependencyFile extends Trait { */ PythonDependencyFile withAddedDependencies(Map additions, @Nullable String scope, @Nullable String groupName); + /** + * Pin transitive dependencies using the strategy appropriate for this file's + * package manager. For pyproject.toml: uv uses {@code [tool.uv].constraint-dependencies}, + * PDM uses {@code [tool.pdm.overrides]}, and other managers add a direct dependency. + * For requirements.txt: appends the dependency. + * + * @param pins normalized package name → version constraint + */ + PythonDependencyFile withPinnedTransitiveDependencies(Map pins); + /** * Add search result markers for vulnerable dependencies. * diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java index 0adfd8164f..de9954325d 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java @@ -109,6 +109,11 @@ public RequirementsFile withAddedDependencies(Map additions, @Nu return this; } + @Override + public RequirementsFile withPinnedTransitiveDependencies(Map pins) { + return withAddedDependencies(pins, null, null); + } + @Override public RequirementsFile withDependencySearchMarkers(Map packageMessages, ExecutionContext ctx) { PlainText result = (PlainText) getTree(); From 3013e787c626f351df7c75554a0843b178ff9706 Mon Sep 17 00:00:00 2001 From: Jente Sondervorst Date: Fri, 3 Apr 2026 17:41:03 +0200 Subject: [PATCH 05/10] Also use the trait for RemoveDependency and ChangeDependency --- .../openrewrite/python/ChangeDependency.java | 154 +++++------------- .../openrewrite/python/RemoveDependency.java | 143 +++++----------- .../python/trait/PyProjectFile.java | 131 ++++++++++++++- .../python/trait/PythonDependencyFile.java | 44 ++++- .../python/trait/RequirementsFile.java | 71 +++++++- 5 files changed, 317 insertions(+), 226 deletions(-) diff --git a/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java index d945fb7978..c4086e8e76 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java @@ -18,16 +18,11 @@ import lombok.EqualsAndHashCode; import lombok.Value; import org.jspecify.annotations.Nullable; -import org.openrewrite.ExecutionContext; -import org.openrewrite.Option; -import org.openrewrite.ScanningRecipe; -import org.openrewrite.TreeVisitor; +import org.openrewrite.*; import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.internal.PythonDependencyExecutionContextView; -import org.openrewrite.python.marker.PythonResolutionResult; -import org.openrewrite.toml.TomlIsoVisitor; +import org.openrewrite.python.trait.PythonDependencyFile; import org.openrewrite.toml.tree.Toml; -import org.openrewrite.toml.tree.TomlType; import java.util.*; @@ -84,130 +79,67 @@ public Accumulator getInitialValue(ExecutionContext ctx) { @Override public TreeVisitor getScanner(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("uv.lock")) { - PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( - PyProjectHelper.correspondingPyprojectPath(sourcePath), - document.printAll()); - return document; + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; } - - if (!sourcePath.endsWith("pyproject.toml")) { - return document; + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + if (tree instanceof Toml.Document && sourceFile.getSourcePath().toString().endsWith("uv.lock")) { + PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( + PyProjectHelper.correspondingPyprojectPath(sourceFile.getSourcePath().toString()), + ((Toml.Document) tree).printAll()); + return tree; } - Optional resolution = document.getMarkers() - .findFirst(PythonResolutionResult.class); - if (!resolution.isPresent()) { - return document; + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait == null) { + return tree; } - - PythonResolutionResult marker = resolution.get(); - if (marker.findDependencyInAnyScope(oldPackageName) != null) { - acc.projectsToUpdate.add(sourcePath); + if (trait.getMarker().findDependencyInAnyScope(oldPackageName) != null) { + acc.projectsToUpdate.add(sourceFile.getSourcePath().toString()); } - return document; + return tree; } }; } @Override public TreeVisitor getVisitor(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("pyproject.toml") && acc.projectsToUpdate.contains(sourcePath)) { - return changeDependencyInPyproject(document, ctx, acc); + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; + } + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + String sourcePath = sourceFile.getSourcePath().toString(); + + if (acc.projectsToUpdate.contains(sourcePath)) { + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait != null) { + PythonDependencyFile updated = trait.withChangedDependency(oldPackageName, newPackageName, newVersion); + SourceFile result = (SourceFile) updated.getTree(); + if (result != tree) { + if (result instanceof Toml.Document) { + return PyProjectHelper.regenerateLockAndRefreshMarker((Toml.Document) result, ctx); + } + return result; + } + } } - if (sourcePath.endsWith("uv.lock")) { - Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock(document, ctx); + if (tree instanceof Toml.Document && sourcePath.endsWith("uv.lock")) { + Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock((Toml.Document) tree, ctx); if (updatedLock != null) { return updatedLock; } } - return document; + return tree; } }; } - - private Toml.Document changeDependencyInPyproject(Toml.Document document, ExecutionContext ctx, Accumulator acc) { - String normalizedOld = PythonResolutionResult.normalizeName(oldPackageName); - - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { - @Override - public Toml.Literal visitLiteral(Toml.Literal literal, ExecutionContext ctx) { - Toml.Literal l = super.visitLiteral(literal, ctx); - if (l.getType() != TomlType.Primitive.String) { - return l; - } - - Object val = l.getValue(); - if (!(val instanceof String)) { - return l; - } - - String spec = (String) val; - String depName = PyProjectHelper.extractPackageName(spec); - if (depName == null || !PythonResolutionResult.normalizeName(depName).equals(normalizedOld)) { - return l; - } - - // Build new PEP 508 string - String extras = UpgradeDependencyVersion.extractExtras(spec); - String marker = UpgradeDependencyVersion.extractMarker(spec); - - StringBuilder sb = new StringBuilder(newPackageName); - if (extras != null) { - sb.append('[').append(extras).append(']'); - } - if (newVersion != null) { - sb.append(PyProjectHelper.normalizeVersionConstraint(newVersion)); - } else { - // Preserve the original version constraint - String originalVersion = extractVersionConstraint(spec, depName); - if (originalVersion != null) { - sb.append(originalVersion); - } - } - if (marker != null) { - sb.append("; ").append(marker); - } - - String newSpec = sb.toString(); - return l.withSource("\"" + newSpec + "\"").withValue(newSpec); - } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; - } - - /** - * Extract the version constraint portion from a PEP 508 spec. - * Returns the version constraint (e.g. ">=2.28.0") or null if there is none. - */ - private static @Nullable String extractVersionConstraint(String spec, String name) { - String remaining = spec.substring(name.length()).trim(); - // Skip extras [...] - if (remaining.startsWith("[")) { - int end = remaining.indexOf(']'); - if (end >= 0) { - remaining = remaining.substring(end + 1).trim(); - } - } - // Extract version constraint up to marker - int markerIdx = remaining.indexOf(';'); - String versionPart = markerIdx >= 0 ? remaining.substring(0, markerIdx).trim() : remaining.trim(); - return versionPart.isEmpty() ? null : versionPart; - } } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java index 0b62c3f411..0bc3e23f6b 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java @@ -21,11 +21,8 @@ import org.openrewrite.*; import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.internal.PythonDependencyExecutionContextView; -import org.openrewrite.python.marker.PythonResolutionResult; -import org.openrewrite.toml.TomlIsoVisitor; -import org.openrewrite.toml.tree.Space; +import org.openrewrite.python.trait.PythonDependencyFile; import org.openrewrite.toml.tree.Toml; -import org.openrewrite.toml.tree.TomlRightPadded; import java.util.*; @@ -94,124 +91,70 @@ public Accumulator getInitialValue(ExecutionContext ctx) { @Override public TreeVisitor getScanner(Accumulator acc) { - return new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("uv.lock")) { - PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( - PyProjectHelper.correspondingPyprojectPath(sourcePath), - document.printAll()); - return document; + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; } - - if (!sourcePath.endsWith("pyproject.toml")) { - return document; + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + if (tree instanceof Toml.Document && sourceFile.getSourcePath().toString().endsWith("uv.lock")) { + PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put( + PyProjectHelper.correspondingPyprojectPath(sourceFile.getSourcePath().toString()), + ((Toml.Document) tree).printAll()); + return tree; } - Optional resolution = document.getMarkers() - .findFirst(PythonResolutionResult.class); - if (!resolution.isPresent()) { - return document; + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait == null) { + return tree; } - - PythonResolutionResult marker = resolution.get(); - - // Check if the dependency exists in the target scope - if (PyProjectHelper.findDependencyInScope(marker, packageName, scope, groupName) == null) { - return document; + if (PyProjectHelper.findDependencyInScope(trait.getMarker(), packageName, scope, groupName) == null) { + return tree; } - - acc.projectsToUpdate.add(sourcePath); - return document; + acc.projectsToUpdate.add(sourceFile.getSourcePath().toString()); + return tree; } }; } @Override public TreeVisitor getVisitor(Accumulator acc) { - return new TomlIsoVisitor() { - @Override - public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) { - String sourcePath = document.getSourcePath().toString(); - - if (sourcePath.endsWith("pyproject.toml") && acc.projectsToUpdate.contains(sourcePath)) { - return removeDependencyFromPyproject(document, ctx, acc); - } - - if (sourcePath.endsWith("uv.lock")) { - Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock(document, ctx); - if (updatedLock != null) { - return updatedLock; - } - } - - return document; - } - }; - } - - private Toml.Document removeDependencyFromPyproject(Toml.Document document, ExecutionContext ctx, Accumulator acc) { - String normalizedName = PythonResolutionResult.normalizeName(packageName); - - Toml.Document updated = (Toml.Document) new TomlIsoVisitor() { + return new TreeVisitor() { @Override - public Toml.Array visitArray(Toml.Array array, ExecutionContext ctx) { - Toml.Array a = super.visitArray(array, ctx); - - if (!PyProjectHelper.isInsideDependencyArray(getCursor(), scope, groupName)) { - return a; + public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) { + if (!(tree instanceof SourceFile)) { + return tree; } - - // Find and remove the matching dependency - List> existingPadded = a.getPadding().getValues(); - List> newPadded = new ArrayList<>(); - boolean found = false; - int removedIdx = -1; - - for (int i = 0; i < existingPadded.size(); i++) { - TomlRightPadded padded = existingPadded.get(i); - Toml element = padded.getElement(); - - if (!found && element instanceof Toml.Literal) { - Object val = ((Toml.Literal) element).getValue(); - if (val instanceof String) { - String depName = PyProjectHelper.extractPackageName((String) val); - if (depName != null && PythonResolutionResult.normalizeName(depName).equals(normalizedName)) { - found = true; - removedIdx = i; - continue; + stopAfterPreVisit(); + SourceFile sourceFile = (SourceFile) tree; + String sourcePath = sourceFile.getSourcePath().toString(); + + if (acc.projectsToUpdate.contains(sourcePath)) { + PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); + if (trait != null) { + PythonDependencyFile updated = trait.withRemovedDependencies( + Collections.singleton(packageName), scope, groupName); + SourceFile result = (SourceFile) updated.getTree(); + if (result != tree) { + if (result instanceof Toml.Document) { + return PyProjectHelper.regenerateLockAndRefreshMarker((Toml.Document) result, ctx); } + return result; } } - - newPadded.add(padded); } - if (!found) { - return a; - } - - // If the removed element was the first one, the next element - // may have a space prefix from comma formatting. Transfer the - // removed element's prefix to the first remaining real element. - if (removedIdx == 0 && !newPadded.isEmpty()) { - TomlRightPadded first = newPadded.get(0); - if (!(first.getElement() instanceof Toml.Empty)) { - Space originalPrefix = existingPadded.get(removedIdx).getElement().getPrefix(); - newPadded.set(0, first.map(el -> el.withPrefix(originalPrefix))); + if (tree instanceof Toml.Document && sourcePath.endsWith("uv.lock")) { + Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock((Toml.Document) tree, ctx); + if (updatedLock != null) { + return updatedLock; } } - return a.getPadding().withValues(newPadded); + return tree; } - }.visitNonNull(document, ctx); - - if (updated != document) { - updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx); - } - - return updated; + }; } } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java index 81b1a36147..89bf10ef71 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java @@ -45,8 +45,7 @@ public Toml.Literal visitLiteral(Toml.Literal literal, Map u) { return literal; } - String normalizedName = PythonResolutionResult.normalizeName(packageName); - String fixVersion = u.get(normalizedName); + String fixVersion = PythonDependencyFile.getByNormalizedName(u, packageName); if (fixVersion == null) { return literal; } @@ -88,6 +87,134 @@ public PyProjectFile withAddedDependencies(Map additions, @Nulla return this; } + @Override + public PyProjectFile withRemovedDependencies(Set packageNames, @Nullable String scope, @Nullable String groupName) { + Set normalizedNames = new HashSet<>(); + for (String name : packageNames) { + normalizedNames.add(PythonResolutionResult.normalizeName(name)); + } + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor>() { + @Override + public Toml.Array visitArray(Toml.Array array, Set names) { + Toml.Array a = super.visitArray(array, names); + if (!PyProjectHelper.isInsideDependencyArray(getCursor(), scope, groupName)) { + return a; + } + + List> existingPadded = a.getPadding().getValues(); + List> newPadded = new ArrayList<>(); + boolean found = false; + int removedIdx = -1; + + for (int i = 0; i < existingPadded.size(); i++) { + TomlRightPadded padded = existingPadded.get(i); + Toml element = padded.getElement(); + + if (element instanceof Toml.Literal) { + Object val = ((Toml.Literal) element).getValue(); + if (val instanceof String) { + String depName = PyProjectHelper.extractPackageName((String) val); + if (depName != null && names.contains(PythonResolutionResult.normalizeName(depName))) { + if (!found) { + removedIdx = i; + } + found = true; + continue; + } + } + } + newPadded.add(padded); + } + + if (!found) { + return a; + } + + if (removedIdx == 0 && !newPadded.isEmpty()) { + TomlRightPadded first = newPadded.get(0); + if (!(first.getElement() instanceof Toml.Empty)) { + Space originalPrefix = existingPadded.get(removedIdx).getElement().getPrefix(); + newPadded.set(0, first.map(el -> el.withPrefix(originalPrefix))); + } + } + + return a.getPadding().withValues(newPadded); + } + }.visitNonNull(doc, normalizedNames); + if (result != doc) { + return new PyProjectFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + @Override + public PyProjectFile withChangedDependency(String oldPackageName, String newPackageName, @Nullable String newVersion) { + String normalizedOld = PythonResolutionResult.normalizeName(oldPackageName); + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor() { + @Override + public Toml.Literal visitLiteral(Toml.Literal literal, Integer p) { + if (literal.getType() != TomlType.Primitive.String) { + return literal; + } + Object val = literal.getValue(); + if (!(val instanceof String)) { + return literal; + } + String spec = (String) val; + String depName = PyProjectHelper.extractPackageName(spec); + if (depName == null || !PythonResolutionResult.normalizeName(depName).equals(normalizedOld)) { + return literal; + } + + String newSpec = buildChangedSpec(spec, depName, newPackageName, newVersion); + return literal.withSource("\"" + newSpec + "\"").withValue(newSpec); + } + }.visitNonNull(doc, 0); + if (result != doc) { + return new PyProjectFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + private static String buildChangedSpec(String oldSpec, String oldName, String newName, @Nullable String newVer) { + StringBuilder sb = new StringBuilder(newName); + + // Preserve extras + int extrasStart = oldSpec.indexOf('[', oldName.length()); + int extrasEnd = oldSpec.indexOf(']', oldName.length()); + if (extrasStart >= 0 && extrasEnd > extrasStart) { + sb.append(oldSpec, extrasStart, extrasEnd + 1); + } + + if (newVer != null) { + sb.append(PyProjectHelper.normalizeVersionConstraint(newVer)); + } else { + // Preserve original version constraint + String remaining = oldSpec.substring(oldName.length()).trim(); + if (remaining.startsWith("[")) { + int end = remaining.indexOf(']'); + if (end >= 0) { + remaining = remaining.substring(end + 1).trim(); + } + } + int markerIdx = remaining.indexOf(';'); + String versionPart = markerIdx >= 0 ? remaining.substring(0, markerIdx).trim() : remaining.trim(); + if (!versionPart.isEmpty()) { + sb.append(versionPart); + } + } + + // Preserve environment markers + int semiIdx = oldSpec.indexOf(';'); + if (semiIdx >= 0) { + sb.append("; ").append(oldSpec.substring(semiIdx + 1).trim()); + } + + return sb.toString(); + } + @Override public PyProjectFile withPinnedTransitiveDependencies(Map pins) { PythonResolutionResult.PackageManager pm = marker.getPackageManager(); diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java index c8bec7287c..782a8399a1 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java @@ -14,9 +14,7 @@ import org.openrewrite.trait.SimpleTraitMatcher; import org.openrewrite.trait.Trait; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; +import java.util.*; /** * Trait for Python dependency files (pyproject.toml, requirements.txt, etc.). @@ -31,7 +29,7 @@ public interface PythonDependencyFile extends Trait { /** * Upgrade version constraints for dependencies in the specified scope. * - * @param upgrades normalized package name → new version + * @param upgrades package name → new version * @param scope the TOML scope, or {@code null} for the default ({@code [project].dependencies}) * @param groupName required for {@code "project.optional-dependencies"} or {@code "dependency-groups"} */ @@ -40,7 +38,7 @@ public interface PythonDependencyFile extends Trait { /** * Add dependencies to the specified scope. * - * @param additions normalized package name → version constraint (e.g. {@code "2.0"} or {@code ">=2.0"}) + * @param additions package name → version constraint (e.g. {@code "2.0"} or {@code ">=2.0"}) * @param scope the TOML scope (e.g. {@code "project.optional-dependencies"}, * {@code "dependency-groups"}), or {@code null} for the default * ({@code [project].dependencies}) @@ -55,14 +53,32 @@ public interface PythonDependencyFile extends Trait { * PDM uses {@code [tool.pdm.overrides]}, and other managers add a direct dependency. * For requirements.txt: appends the dependency. * - * @param pins normalized package name → version constraint + * @param pins package name → version constraint */ PythonDependencyFile withPinnedTransitiveDependencies(Map pins); + /** + * Remove dependencies from the specified scope. + * + * @param packageNames package names to remove + * @param scope the TOML scope, or {@code null} for the default ({@code [project].dependencies}) + * @param groupName required for {@code "project.optional-dependencies"} or {@code "dependency-groups"} + */ + PythonDependencyFile withRemovedDependencies(Set packageNames, @Nullable String scope, @Nullable String groupName); + + /** + * Change a dependency to a different package, searching all scopes. + * + * @param oldPackageName the current package name + * @param newPackageName the new package name + * @param newVersion optional new version constraint, or {@code null} to preserve the original + */ + PythonDependencyFile withChangedDependency(String oldPackageName, String newPackageName, @Nullable String newVersion); + /** * Add search result markers for vulnerable dependencies. * - * @param packageMessages normalized package name → vulnerability description message + * @param packageMessages package name → vulnerability description message */ PythonDependencyFile withDependencySearchMarkers(Map packageMessages, ExecutionContext ctx); @@ -97,6 +113,20 @@ static String rewritePep508Spec(String spec, String packageName, String newVersi return sb.toString(); } + /** + * Look up a value in a map by normalizing the lookup key per PEP 503. + * This allows callers to pass non-normalized package names. + */ + static @Nullable String getByNormalizedName(Map map, String name) { + String normalized = PythonResolutionResult.normalizeName(name); + for (Map.Entry entry : map.entrySet()) { + if (PythonResolutionResult.normalizeName(entry.getKey()).equals(normalized)) { + return entry.getValue(); + } + } + return null; + } + /** * Update the resolved dependency versions in a marker to reflect version changes. * Returns the same marker if no changes were needed. diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java index de9954325d..e7703f7a0f 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java @@ -16,9 +16,7 @@ import org.openrewrite.text.PlainText; import org.openrewrite.trait.SimpleTraitMatcher; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; +import java.util.*; @Value public class RequirementsFile implements PythonDependencyFile { @@ -46,8 +44,7 @@ public RequirementsFile withUpgradedVersions(Map upgrades, @Null continue; } - String normalizedName = PythonResolutionResult.normalizeName(packageName); - String fixVersion = upgrades.get(normalizedName); + String fixVersion = PythonDependencyFile.getByNormalizedName(upgrades, packageName); if (fixVersion == null) { continue; } @@ -92,7 +89,7 @@ public RequirementsFile withAddedDependencies(Map additions, @Nu StringBuilder sb = new StringBuilder(text); boolean changed = false; for (Map.Entry entry : additions.entrySet()) { - if (!existingPackages.contains(entry.getKey())) { + if (!existingPackages.contains(PythonResolutionResult.normalizeName(entry.getKey()))) { sb.append("\n").append(entry.getKey()).append(PyProjectHelper.normalizeVersionConstraint(entry.getValue())); changed = true; } @@ -109,6 +106,68 @@ public RequirementsFile withAddedDependencies(Map additions, @Nu return this; } + @Override + public RequirementsFile withRemovedDependencies(Set packageNames, @Nullable String scope, @Nullable String groupName) { + Set normalizedNames = new HashSet<>(); + for (String name : packageNames) { + normalizedNames.add(PythonResolutionResult.normalizeName(name)); + } + PlainText pt = (PlainText) getTree(); + String[] lines = pt.getText().split("\n", -1); + List kept = new ArrayList<>(); + boolean changed = false; + + for (String line : lines) { + String pkg = PyProjectHelper.extractPackageName(line.trim()); + if (pkg != null && normalizedNames.contains(PythonResolutionResult.normalizeName(pkg))) { + changed = true; + } else { + kept.add(line); + } + } + + if (changed) { + PlainText newPt = pt.withText(String.join("\n", kept)); + return new RequirementsFile(new Cursor(cursor.getParentOrThrow(), newPt), marker); + } + return this; + } + + @Override + public RequirementsFile withChangedDependency(String oldPackageName, String newPackageName, @Nullable String newVersion) { + PlainText pt = (PlainText) getTree(); + String[] lines = pt.getText().split("\n", -1); + boolean changed = false; + String normalizedOld = PythonResolutionResult.normalizeName(oldPackageName); + + for (int i = 0; i < lines.length; i++) { + String trimmed = lines[i].trim(); + String pkg = PyProjectHelper.extractPackageName(trimmed); + if (pkg != null && PythonResolutionResult.normalizeName(pkg).equals(normalizedOld)) { + // Preserve leading whitespace + int leadingWs = 0; + while (leadingWs < lines[i].length() && Character.isWhitespace(lines[i].charAt(leadingWs))) { + leadingWs++; + } + String newSpec; + if (newVersion != null) { + newSpec = newPackageName + PyProjectHelper.normalizeVersionConstraint(newVersion); + } else { + // Replace just the name, keep the rest + newSpec = newPackageName + trimmed.substring(pkg.length()); + } + lines[i] = lines[i].substring(0, leadingWs) + newSpec; + changed = true; + } + } + + if (changed) { + PlainText newPt = pt.withText(String.join("\n", lines)); + return new RequirementsFile(new Cursor(cursor.getParentOrThrow(), newPt), marker); + } + return this; + } + @Override public RequirementsFile withPinnedTransitiveDependencies(Map pins) { return withAddedDependencies(pins, null, null); From f495f4d33b54ec6cd7b481dc97df4e485c7170c8 Mon Sep 17 00:00:00 2001 From: Jente Sondervorst Date: Fri, 3 Apr 2026 20:09:10 +0200 Subject: [PATCH 06/10] Improved docs --- .../org/openrewrite/python/AddDependency.java | 17 +++++---- .../openrewrite/python/ChangeDependency.java | 12 ++++--- .../openrewrite/python/RemoveDependency.java | 10 ++++-- .../python/UpgradeDependencyVersion.java | 31 +++++----------- .../UpgradeTransitiveDependencyVersion.java | 26 +++++++------- .../python/trait/RequirementsFile.java | 35 +++++++++++++++++++ 6 files changed, 82 insertions(+), 49 deletions(-) diff --git a/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java index ebb9b1f2ca..fd4f10bb1b 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java @@ -21,15 +21,17 @@ import org.openrewrite.*; import org.openrewrite.python.internal.PyProjectHelper; import org.openrewrite.python.internal.PythonDependencyExecutionContextView; -import org.openrewrite.python.marker.PythonResolutionResult; import org.openrewrite.python.trait.PythonDependencyFile; -import org.openrewrite.toml.TomlIsoVisitor; import org.openrewrite.toml.tree.Toml; -import java.util.*; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; /** - * Add a dependency to the {@code [project].dependencies} array in pyproject.toml. + * Add a dependency to a Python project. Supports both {@code pyproject.toml} + * (with scope and group targeting) and {@code requirements.txt} files. * When uv is available, the uv.lock file is regenerated to reflect the change. */ @EqualsAndHashCode(callSuper = false) @@ -49,7 +51,9 @@ public class AddDependency extends ScanningRecipe { String version; @Option(displayName = "Scope", - description = "The dependency scope to add to. Defaults to `project.dependencies`.", + description = "The dependency scope to add to. For pyproject.toml this targets a specific TOML section. " + + "For requirements files, `null` matches all files, empty string matches only `requirements.txt`, " + + "and a value like `dev` matches `requirements-dev.txt`. Defaults to `project.dependencies`.", valid = {"project.dependencies", "project.optional-dependencies", "dependency-groups", "tool.uv.constraint-dependencies", "tool.uv.override-dependencies"}, example = "project.dependencies", @@ -85,7 +89,8 @@ public String getInstanceNameSuffix() { @Override public String getDescription() { - return "Add a dependency to the `[project].dependencies` array in `pyproject.toml`. " + + return "Add a dependency to a Python project. Supports `pyproject.toml` " + + "(with scope/group targeting) and `requirements.txt` files. " + "When `uv` is available, the `uv.lock` file is regenerated."; } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java index c4086e8e76..03e70e2829 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java @@ -24,11 +24,12 @@ import org.openrewrite.python.trait.PythonDependencyFile; import org.openrewrite.toml.tree.Toml; -import java.util.*; +import java.util.HashSet; +import java.util.Set; /** - * Change a dependency to a different package in pyproject.toml. - * Searches all dependency arrays in the document (no scope restriction). + * Change a dependency to a different package. Supports both {@code pyproject.toml} + * and {@code requirements.txt} files. Searches all dependency scopes. * When uv is available, the uv.lock file is regenerated to reflect the change. */ @EqualsAndHashCode(callSuper = false) @@ -64,8 +65,9 @@ public String getInstanceNameSuffix() { @Override public String getDescription() { - return "Change a dependency to a different package in `pyproject.toml`. " + - "Searches all dependency arrays. When `uv` is available, the `uv.lock` file is regenerated."; + return "Change a dependency to a different package. Supports `pyproject.toml` " + + "and `requirements.txt` files. Searches all dependency scopes. " + + "When `uv` is available, the `uv.lock` file is regenerated."; } static class Accumulator { diff --git a/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java index 0bc3e23f6b..cb97c1f633 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java @@ -24,10 +24,13 @@ import org.openrewrite.python.trait.PythonDependencyFile; import org.openrewrite.toml.tree.Toml; -import java.util.*; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; /** - * Remove a dependency from the {@code [project].dependencies} array in pyproject.toml. + * Remove a dependency from a Python project. Supports both {@code pyproject.toml} + * (with scope and group targeting) and {@code requirements.txt} files. * When uv is available, the uv.lock file is regenerated to reflect the change. */ @EqualsAndHashCode(callSuper = false) @@ -76,7 +79,8 @@ public String getInstanceNameSuffix() { @Override public String getDescription() { - return "Remove a dependency from the `[project].dependencies` array in `pyproject.toml`. " + + return "Remove a dependency from a Python project. Supports `pyproject.toml` " + + "(with scope/group targeting) and `requirements.txt` files. " + "When `uv` is available, the `uv.lock` file is regenerated."; } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java index 5380f02e74..5bf51660ec 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java @@ -23,13 +23,16 @@ import org.openrewrite.python.internal.PythonDependencyExecutionContextView; import org.openrewrite.python.marker.PythonResolutionResult; import org.openrewrite.python.trait.PythonDependencyFile; -import org.openrewrite.toml.TomlIsoVisitor; import org.openrewrite.toml.tree.Toml; -import java.util.*; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; /** - * Upgrade the version constraint for a dependency in {@code [project].dependencies} in pyproject.toml. + * Upgrade the version constraint for a dependency. Supports both {@code pyproject.toml} + * (with scope and group targeting) and {@code requirements.txt} files. * When uv is available, the uv.lock file is regenerated to reflect the change. */ @EqualsAndHashCode(callSuper = false) @@ -83,7 +86,8 @@ public String getInstanceNameSuffix() { @Override public String getDescription() { - return "Upgrade the version constraint for a dependency in `[project].dependencies` in `pyproject.toml`. " + + return "Upgrade the version constraint for a dependency. Supports `pyproject.toml` " + + "(with scope/group targeting) and `requirements.txt` files. " + "When `uv` is available, the `uv.lock` file is regenerated."; } @@ -169,23 +173,4 @@ public TreeVisitor getVisitor(Accumulator acc) { } }; } - - static @Nullable String extractExtras(String pep508Spec) { - int start = pep508Spec.indexOf('['); - int end = pep508Spec.indexOf(']'); - if (start >= 0 && end > start) { - return pep508Spec.substring(start + 1, end); - } - return null; - } - - static @Nullable String extractMarker(String pep508Spec) { - int idx = pep508Spec.indexOf(';'); - if (idx >= 0) { - String marker = pep508Spec.substring(idx + 1).trim(); - return marker.isEmpty() ? null : marker; - } - return null; - } - } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java index 2a6ad1427c..6d617febef 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java @@ -25,16 +25,17 @@ import org.openrewrite.python.trait.PythonDependencyFile; import org.openrewrite.toml.tree.Toml; -import java.util.*; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; /** - * Pin a transitive dependency version by adding or upgrading a constraint in the - * appropriate tool-specific section. The strategy depends on the detected package manager: - *

    - *
  • uv: uses {@code [tool.uv].constraint-dependencies}
  • - *
  • PDM: uses {@code [tool.pdm.overrides]}
  • - *
  • Other/unknown: adds as a direct dependency in {@code [project].dependencies}
  • - *
+ * Pin a transitive dependency version using the strategy appropriate for the file type + * and package manager. For {@code pyproject.toml}: uv uses + * {@code [tool.uv].constraint-dependencies}, PDM uses {@code [tool.pdm.overrides]}, + * and other managers add a direct dependency. For {@code requirements.txt}: appends + * the dependency. When uv is available, the uv.lock file is regenerated. */ @EqualsAndHashCode(callSuper = false) @Value @@ -62,9 +63,10 @@ public String getInstanceNameSuffix() { @Override public String getDescription() { - return "Pin a transitive dependency version using the appropriate strategy for the " + - "detected package manager: uv uses `[tool.uv].constraint-dependencies`, " + - "PDM uses `[tool.pdm.overrides]`, and other managers add a direct dependency."; + return "Pin a transitive dependency version using the strategy appropriate for the file type " + + "and package manager. For `pyproject.toml`: uv uses `[tool.uv].constraint-dependencies`, " + + "PDM uses `[tool.pdm.overrides]`, and other managers add a direct dependency. " + + "For `requirements.txt`: appends the dependency."; } static class Accumulator { @@ -133,7 +135,7 @@ public TreeVisitor getVisitor(Accumulator acc) { String normalizedName = PythonResolutionResult.normalizeName(packageName); Map pins = Collections.singletonMap(normalizedName, version); PythonDependencyFile updated = trait.withPinnedTransitiveDependencies(pins); - SourceFile result = (SourceFile) updated.getTree(); + SourceFile result = updated.getTree(); if (result != tree) { if (result instanceof Toml.Document) { return PyProjectHelper.regenerateLockAndRefreshMarker((Toml.Document) result, ctx); diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java index e7703f7a0f..c9644c1a9f 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/RequirementsFile.java @@ -17,16 +17,45 @@ import org.openrewrite.trait.SimpleTraitMatcher; import java.util.*; +import java.util.regex.Pattern; @Value public class RequirementsFile implements PythonDependencyFile { private static final RequirementsTxtParser PARSER = new RequirementsTxtParser(); + private static final Pattern SCOPE_PATTERN = Pattern.compile("requirements(?:-([\\w-]+))?\\.(?:txt|in)"); Cursor cursor; PythonResolutionResult marker; + /** + * Check whether this file matches the given scope. + *
    + *
  • {@code null} → matches all requirements files
  • + *
  • {@code ""} (empty) → matches only {@code requirements.txt} / {@code requirements.in}
  • + *
  • {@code "dev"} → matches only {@code requirements-dev.txt} / {@code requirements-dev.in}
  • + *
+ */ + private boolean matchesScope(@Nullable String scope) { + if (scope == null) { + return true; + } + String filename = getTree().getSourcePath().getFileName().toString(); + java.util.regex.Matcher m = SCOPE_PATTERN.matcher(filename); + if (!m.matches()) { + return false; + } + String fileSuffix = m.group(1); // null for requirements.txt, "dev" for requirements-dev.txt + if (scope.isEmpty()) { + return fileSuffix == null; + } + return scope.equals(fileSuffix); + } + @Override public RequirementsFile withUpgradedVersions(Map upgrades, @Nullable String scope, @Nullable String groupName) { + if (!matchesScope(scope)) { + return this; + } PlainText pt = (PlainText) getTree(); String text = pt.getText(); String[] lines = text.split("\n", -1); @@ -74,6 +103,9 @@ public RequirementsFile withUpgradedVersions(Map upgrades, @Null @Override public RequirementsFile withAddedDependencies(Map additions, @Nullable String scope, @Nullable String groupName) { + if (!matchesScope(scope)) { + return this; + } PlainText pt = (PlainText) getTree(); String text = pt.getText(); String[] lines = text.split("\n", -1); @@ -108,6 +140,9 @@ public RequirementsFile withAddedDependencies(Map additions, @Nu @Override public RequirementsFile withRemovedDependencies(Set packageNames, @Nullable String scope, @Nullable String groupName) { + if (!matchesScope(scope)) { + return this; + } Set normalizedNames = new HashSet<>(); for (String name : packageNames) { normalizedNames.add(PythonResolutionResult.normalizeName(name)); From 40a0f9a0960fd0127023ec64376a7c6a43bfb025 Mon Sep 17 00:00:00 2001 From: Jente Sondervorst Date: Fri, 3 Apr 2026 20:28:46 +0200 Subject: [PATCH 07/10] Added tests for requirements.txt files --- .../openrewrite/python/AddDependencyTest.java | 19 ++++ .../python/ChangeDependencyTest.java | 32 ++++++- .../python/RemoveDependencyTest.java | 22 ++++- .../python/UpgradeDependencyVersionTest.java | 22 ++++- .../trait/PythonDependencyFileTest.java | 90 +++++++++++++++++++ 5 files changed, 180 insertions(+), 5 deletions(-) diff --git a/rewrite-python/src/test/java/org/openrewrite/python/AddDependencyTest.java b/rewrite-python/src/test/java/org/openrewrite/python/AddDependencyTest.java index 166746d851..478183a5f9 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/AddDependencyTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/AddDependencyTest.java @@ -341,6 +341,25 @@ void addDependencyWithBareVersion() { ); } + @Test + void addDependencyToRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new AddDependency("flask", ">=2.0", null, null)), + requirementsTxt( + "requests>=2.28.0", + "requests>=2.28.0\nflask>=2.0" + ) + ); + } + + @Test + void skipWhenAlreadyPresentInRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new AddDependency("requests", null, null, null)), + requirementsTxt("requests>=2.28.0") + ); + } + @Test void validateRequiresGroupName() { var recipe = new AddDependency("pytest", null, "project.optional-dependencies", null); diff --git a/rewrite-python/src/test/java/org/openrewrite/python/ChangeDependencyTest.java b/rewrite-python/src/test/java/org/openrewrite/python/ChangeDependencyTest.java index fd142c38d9..b1a141588b 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/ChangeDependencyTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/ChangeDependencyTest.java @@ -18,7 +18,7 @@ import org.junit.jupiter.api.Test; import org.openrewrite.test.RewriteTest; -import static org.openrewrite.python.Assertions.pyproject; +import static org.openrewrite.python.Assertions.*; class ChangeDependencyTest implements RewriteTest { @@ -171,4 +171,34 @@ void renameAcrossScopes() { ) ); } + + @Test + void renamePackageInRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new ChangeDependency("requests", "httpx", null)), + requirementsTxt( + "requests>=2.28.0\nclick>=8.0", + "httpx>=2.28.0\nclick>=8.0" + ) + ); + } + + @Test + void renameWithNewVersionInRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new ChangeDependency("requests", "httpx", ">=0.24.0")), + requirementsTxt( + "requests>=2.28.0\nclick>=8.0", + "httpx>=0.24.0\nclick>=8.0" + ) + ); + } + + @Test + void skipWhenNotFoundInRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new ChangeDependency("flask", "quart", null)), + requirementsTxt("requests>=2.28.0") + ); + } } diff --git a/rewrite-python/src/test/java/org/openrewrite/python/RemoveDependencyTest.java b/rewrite-python/src/test/java/org/openrewrite/python/RemoveDependencyTest.java index 9bfbcdc7ba..3b96ff028e 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/RemoveDependencyTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/RemoveDependencyTest.java @@ -21,8 +21,7 @@ import java.nio.file.Path; -import static org.openrewrite.python.Assertions.pyproject; -import static org.openrewrite.python.Assertions.uv; +import static org.openrewrite.python.Assertions.*; class RemoveDependencyTest implements RewriteTest { @@ -261,4 +260,23 @@ void removeFromDependencyGroup() { ) ); } + + @Test + void removeDependencyFromRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new RemoveDependency("click", null, null)), + requirementsTxt( + "requests>=2.28.0\nclick>=8.0", + "requests>=2.28.0" + ) + ); + } + + @Test + void skipWhenNotPresentInRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new RemoveDependency("flask", null, null)), + requirementsTxt("requests>=2.28.0") + ); + } } diff --git a/rewrite-python/src/test/java/org/openrewrite/python/UpgradeDependencyVersionTest.java b/rewrite-python/src/test/java/org/openrewrite/python/UpgradeDependencyVersionTest.java index 0ec4c7b48c..1108478557 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/UpgradeDependencyVersionTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/UpgradeDependencyVersionTest.java @@ -21,8 +21,7 @@ import java.nio.file.Path; -import static org.openrewrite.python.Assertions.pyproject; -import static org.openrewrite.python.Assertions.uv; +import static org.openrewrite.python.Assertions.*; class UpgradeDependencyVersionTest implements RewriteTest { @@ -285,4 +284,23 @@ void changeVersionInDependencyGroup() { ) ); } + + @Test + void changeVersionInRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new UpgradeDependencyVersion("requests", ">=2.31.0", null, null)), + requirementsTxt( + "requests>=2.28.0\nclick>=8.0", + "requests>=2.31.0\nclick>=8.0" + ) + ); + } + + @Test + void skipWhenNotPresentInRequirementsTxt() { + rewriteRun( + spec -> spec.recipe(new UpgradeDependencyVersion("flask", ">=3.0", null, null)), + requirementsTxt("requests>=2.28.0") + ); + } } diff --git a/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java b/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java index 910c37ded7..18aa010735 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/trait/PythonDependencyFileTest.java @@ -658,4 +658,94 @@ void searchMarkersNoOpViaMatcher() { ); } } + + @Nested + class RequirementsScopeFilterTest { + + @Test + void nullScopeMatchesAllFiles() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, null, null); + + assertThat(((PlainText) upgraded.getTree()).getText()).contains("requests>=2.31.0"); + } + + @Test + void emptyScopeMatchesRootRequirementsTxt() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + // requirements.txt (no suffix) should match scope="" + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, "", null); + + assertThat(((PlainText) upgraded.getTree()).getText()).contains("requests>=2.31.0"); + } + + @Test + void emptyScopeDoesNotMatchScopedFile() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = new PlainText( + randomId(), Paths.get("requirements-dev.txt"), + Markers.EMPTY.addIfAbsent(marker), + "UTF-8", false, null, null, "requests>=2.28.0", null + ); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, "", null); + + // scope="" should NOT match requirements-dev.txt + assertThat(upgraded).isSameAs(trait); + } + + @Test + void devScopeMatchesDevFile() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = new PlainText( + randomId(), Paths.get("requirements-dev.txt"), + Markers.EMPTY.addIfAbsent(marker), + "UTF-8", false, null, null, "requests>=2.28.0", null + ); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, "dev", null); + + assertThat(((PlainText) upgraded.getTree()).getText()).contains("requests>=2.31.0"); + } + + @Test + void devScopeDoesNotMatchRootFile() { + ResolvedDependency resolved = new ResolvedDependency("requests", "2.28.0", null, null); + PythonResolutionResult marker = createMarker(Collections.emptyList(), + Collections.singletonList(resolved)); + + PlainText pt = createRequirementsTxt("requests>=2.28.0", marker); + RequirementsFile trait = requirementsTrait(pt, marker); + + Map upgrades = Collections.singletonMap("requests", "2.31.0"); + RequirementsFile upgraded = trait.withUpgradedVersions(upgrades, "dev", null); + + // scope="dev" should NOT match requirements.txt + assertThat(upgraded).isSameAs(trait); + } + } } From 2cb6a834e4d80d67c55999475afdce2a1aa5bb7d Mon Sep 17 00:00:00 2001 From: Jente Sondervorst Date: Fri, 3 Apr 2026 21:17:10 +0200 Subject: [PATCH 08/10] Added Pipfile support --- .../org/openrewrite/python/Assertions.java | 39 ++ .../org/openrewrite/python/PipfileParser.java | 192 ++++++++++ .../openrewrite/python/trait/PipfileFile.java | 347 ++++++++++++++++++ .../python/trait/PythonDependencyFile.java | 12 +- .../openrewrite/python/AddDependencyTest.java | 18 + .../python/ChangeDependencyTest.java | 19 + .../python/RemoveDependencyTest.java | 18 + .../python/UpgradeDependencyVersionTest.java | 19 + ...pgradeTransitiveDependencyVersionTest.java | 34 +- .../python/trait/PipfileFileTest.java | 297 +++++++++++++++ 10 files changed, 991 insertions(+), 4 deletions(-) create mode 100644 rewrite-python/src/main/java/org/openrewrite/python/PipfileParser.java create mode 100644 rewrite-python/src/main/java/org/openrewrite/python/trait/PipfileFile.java create mode 100644 rewrite-python/src/test/java/org/openrewrite/python/trait/PipfileFileTest.java diff --git a/rewrite-python/src/main/java/org/openrewrite/python/Assertions.java b/rewrite-python/src/main/java/org/openrewrite/python/Assertions.java index a785c7d463..2df0a0ab8b 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/Assertions.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/Assertions.java @@ -237,6 +237,45 @@ public static SourceSpecs setupCfg(@Nullable String before, return text; } + public static SourceSpecs pipfile(@Language("toml") @Nullable String before) { + return pipfile(before, s -> { + }); + } + + public static SourceSpecs pipfile(@Language("toml") @Nullable String before, + Consumer> spec) { + SourceSpec toml = new SourceSpec<>( + Toml.Document.class, null, PipfileParser.builder(), before, + SourceSpec.ValidateSource.noop, + ctx -> { + } + ); + toml.path("Pipfile"); + spec.accept(toml); + return toml; + } + + public static SourceSpecs pipfile(@Language("toml") @Nullable String before, + @Language("toml") @Nullable String after) { + return pipfile(before, after, s -> { + }); + } + + public static SourceSpecs pipfile(@Language("toml") @Nullable String before, + @Language("toml") @Nullable String after, + Consumer> spec) { + SourceSpec toml = new SourceSpec<>( + Toml.Document.class, null, PipfileParser.builder(), before, + SourceSpec.ValidateSource.noop, + ctx -> { + } + ); + toml.path("Pipfile"); + toml.after(s -> after); + spec.accept(toml); + return toml; + } + public static SourceSpecs python(@Language("py") @Nullable String before) { return python(before, s -> { }); diff --git a/rewrite-python/src/main/java/org/openrewrite/python/PipfileParser.java b/rewrite-python/src/main/java/org/openrewrite/python/PipfileParser.java new file mode 100644 index 0000000000..5d3e11df39 --- /dev/null +++ b/rewrite-python/src/main/java/org/openrewrite/python/PipfileParser.java @@ -0,0 +1,192 @@ +/* + * Copyright 2026 the original author or authors. + * + * Moderne Proprietary. Only for use by Moderne customers under the terms of a commercial contract. + */ +package org.openrewrite.python; + +import org.jspecify.annotations.Nullable; +import org.openrewrite.ExecutionContext; +import org.openrewrite.Parser; +import org.openrewrite.SourceFile; +import org.openrewrite.python.internal.PyProjectHelper; +import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.python.marker.PythonResolutionResult.Dependency; +import org.openrewrite.python.marker.PythonResolutionResult.PackageManager; +import org.openrewrite.toml.TomlParser; +import org.openrewrite.toml.tree.Toml; + +import java.nio.file.Path; +import java.util.*; +import java.util.stream.Stream; + +import static org.openrewrite.Tree.randomId; + +/** + * Parser for Pipfile files that delegates to {@link TomlParser} and attaches a + * {@link PythonResolutionResult} marker with dependency metadata. + */ +public class PipfileParser implements Parser { + + private final TomlParser tomlParser = new TomlParser(); + + @Override + public Stream parseInputs(Iterable sources, @Nullable Path relativeTo, ExecutionContext ctx) { + return tomlParser.parseInputs(sources, relativeTo, ctx).map(sf -> { + if (!(sf instanceof Toml.Document)) { + return sf; + } + Toml.Document doc = (Toml.Document) sf; + PythonResolutionResult marker = createMarker(doc); + if (marker == null) { + return sf; + } + return doc.withMarkers(doc.getMarkers().addIfAbsent(marker)); + }); + } + + static @Nullable PythonResolutionResult createMarker(Toml.Document doc) { + Map tables = indexTables(doc); + + Toml.Table packagesTable = tables.get("packages"); + Toml.Table devPackagesTable = tables.get("dev-packages"); + + // A Pipfile should have at least one dependency section + if (packagesTable == null && devPackagesTable == null) { + return null; + } + + List dependencies = parseDependencyTable(packagesTable); + + Map> optionalDependencies = new LinkedHashMap<>(); + List devDeps = parseDependencyTable(devPackagesTable); + if (!devDeps.isEmpty()) { + optionalDependencies.put("dev-packages", devDeps); + } + + Toml.Table requiresTable = tables.get("requires"); + String requiresPython = requiresTable != null ? getStringValue(requiresTable, "python_version") : null; + + return new PythonResolutionResult( + randomId(), + null, + null, + null, + null, + doc.getSourcePath().toString(), + requiresPython, + null, + Collections.emptyList(), + dependencies, + optionalDependencies, + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + PackageManager.Pipenv, + null + ); + } + + private static List parseDependencyTable(Toml.@Nullable Table table) { + if (table == null) { + return Collections.emptyList(); + } + List deps = new ArrayList<>(); + for (Toml value : table.getValues()) { + if (!(value instanceof Toml.KeyValue)) { + continue; + } + Toml.KeyValue kv = (Toml.KeyValue) value; + if (!(kv.getKey() instanceof Toml.Identifier)) { + continue; + } + String name = ((Toml.Identifier) kv.getKey()).getName(); + String versionConstraint = extractVersion(kv.getValue()); + if ("*".equals(versionConstraint)) { + versionConstraint = null; + } + deps.add(new Dependency(name, versionConstraint, null, null, null)); + } + return deps; + } + + private static @Nullable String extractVersion(Toml value) { + if (value instanceof Toml.Literal) { + Object v = ((Toml.Literal) value).getValue(); + return v instanceof String ? (String) v : null; + } + if (value instanceof Toml.Table) { + // Inline table: {version = ">=3.2", ...} + for (Toml inner : ((Toml.Table) value).getValues()) { + if (inner instanceof Toml.KeyValue) { + Toml.KeyValue innerKv = (Toml.KeyValue) inner; + if (innerKv.getKey() instanceof Toml.Identifier && + "version".equals(((Toml.Identifier) innerKv.getKey()).getName())) { + return extractVersion(innerKv.getValue()); + } + } + } + } + return null; + } + + private static @Nullable String getStringValue(Toml.Table table, String key) { + for (Toml value : table.getValues()) { + if (value instanceof Toml.KeyValue) { + Toml.KeyValue kv = (Toml.KeyValue) value; + if (kv.getKey() instanceof Toml.Identifier && + key.equals(((Toml.Identifier) kv.getKey()).getName()) && + kv.getValue() instanceof Toml.Literal) { + Object v = ((Toml.Literal) kv.getValue()).getValue(); + return v instanceof String ? (String) v : null; + } + } + } + return null; + } + + private static Map indexTables(Toml.Document doc) { + Map tables = new LinkedHashMap<>(); + for (Toml value : doc.getValues()) { + if (value instanceof Toml.Table) { + Toml.Table table = (Toml.Table) value; + if (table.getName() != null) { + tables.put(table.getName().getName(), table); + } + } + } + return tables; + } + + @Override + public boolean accept(Path path) { + return "Pipfile".equals(path.getFileName().toString()); + } + + @Override + public Path sourcePathFromSourceText(Path prefix, String sourceCode) { + return prefix.resolve("Pipfile"); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder extends Parser.Builder { + + Builder() { + super(Toml.Document.class); + } + + @Override + public PipfileParser build() { + return new PipfileParser(); + } + + @Override + public String getDslName() { + return "Pipfile"; + } + } +} diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PipfileFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PipfileFile.java new file mode 100644 index 0000000000..69cdc0d79a --- /dev/null +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PipfileFile.java @@ -0,0 +1,347 @@ +/* + * Copyright 2026 the original author or authors. + * + * Moderne Proprietary. Only for use by Moderne customers under the terms of a commercial contract. + */ +package org.openrewrite.python.trait; + +import lombok.Value; +import org.jspecify.annotations.Nullable; +import org.openrewrite.Cursor; +import org.openrewrite.ExecutionContext; +import org.openrewrite.Tree; +import org.openrewrite.marker.Markers; +import org.openrewrite.marker.SearchResult; +import org.openrewrite.python.internal.PyProjectHelper; +import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.toml.TomlIsoVisitor; +import org.openrewrite.toml.tree.Space; +import org.openrewrite.toml.tree.Toml; +import org.openrewrite.toml.tree.TomlRightPadded; +import org.openrewrite.toml.tree.TomlType; +import org.openrewrite.trait.SimpleTraitMatcher; + +import java.util.*; + +/** + * Trait implementation for Pipfile dependency files. + * Pipfile uses key-value tables: {@code [packages]} for production and + * {@code [dev-packages]} for development dependencies. + */ +@Value +public class PipfileFile implements PythonDependencyFile { + + Cursor cursor; + PythonResolutionResult marker; + + @Override + public PipfileFile withUpgradedVersions(Map upgrades, @Nullable String scope, @Nullable String groupName) { + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor>() { + @Override + public Toml.KeyValue visitKeyValue(Toml.KeyValue keyValue, Map u) { + Toml.KeyValue kv = super.visitKeyValue(keyValue, u); + if (!isInsideTargetTable(getCursor(), scope)) { + return kv; + } + if (!(kv.getKey() instanceof Toml.Identifier)) { + return kv; + } + String pkgName = ((Toml.Identifier) kv.getKey()).getName(); + String newVersion = PythonDependencyFile.getByNormalizedName(u, pkgName); + if (newVersion == null) { + return kv; + } + return updateKeyValueVersion(kv, newVersion); + } + }.visitNonNull(doc, upgrades); + if (result != doc) { + return new PipfileFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + @Override + public PipfileFile withAddedDependencies(Map additions, @Nullable String scope, @Nullable String groupName) { + String tableName = resolveTableName(scope); + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document original = doc; + for (Map.Entry entry : additions.entrySet()) { + String normalizedName = PythonResolutionResult.normalizeName(entry.getKey()); + if (!hasDependencyInTable(doc, tableName, normalizedName)) { + doc = addToTable(doc, tableName, entry.getKey(), entry.getValue()); + } + } + if (doc != original) { + return new PipfileFile(new Cursor(cursor.getParentOrThrow(), doc), marker); + } + return this; + } + + @Override + public PipfileFile withRemovedDependencies(Set packageNames, @Nullable String scope, @Nullable String groupName) { + Set normalizedNames = new HashSet<>(); + for (String name : packageNames) { + normalizedNames.add(PythonResolutionResult.normalizeName(name)); + } + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor>() { + @Override + public Toml.Table visitTable(Toml.Table table, Set names) { + Toml.Table t = super.visitTable(table, names); + if (!isTargetTable(t, scope)) { + return t; + } + List newValues = new ArrayList<>(); + boolean changed = false; + for (Toml value : t.getValues()) { + if (value instanceof Toml.KeyValue) { + Toml.KeyValue kv = (Toml.KeyValue) value; + if (kv.getKey() instanceof Toml.Identifier) { + String keyName = ((Toml.Identifier) kv.getKey()).getName(); + if (names.contains(PythonResolutionResult.normalizeName(keyName))) { + changed = true; + continue; + } + } + } + newValues.add(value); + } + return changed ? t.withValues(newValues) : t; + } + }.visitNonNull(doc, normalizedNames); + if (result != doc) { + return new PipfileFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + @Override + public PipfileFile withChangedDependency(String oldPackageName, String newPackageName, @Nullable String newVersion) { + String normalizedOld = PythonResolutionResult.normalizeName(oldPackageName); + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor() { + @Override + public Toml.KeyValue visitKeyValue(Toml.KeyValue keyValue, Integer p) { + Toml.KeyValue kv = super.visitKeyValue(keyValue, p); + if (!(kv.getKey() instanceof Toml.Identifier)) { + return kv; + } + String keyName = ((Toml.Identifier) kv.getKey()).getName(); + if (!PythonResolutionResult.normalizeName(keyName).equals(normalizedOld)) { + return kv; + } + // Check we're inside [packages] or [dev-packages] + if (!isInsideTargetTable(getCursor(), null)) { + return kv; + } + Toml.Identifier newKey = ((Toml.Identifier) kv.getKey()) + .withName(newPackageName) + .withSource(newPackageName); + kv = kv.getPadding().withKey(kv.getPadding().getKey().withElement(newKey)); + if (newVersion != null) { + kv = updateKeyValueVersion(kv, newVersion); + } + return kv; + } + }.visitNonNull(doc, 0); + if (result != doc) { + return new PipfileFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + @Override + public PipfileFile withPinnedTransitiveDependencies(Map pins) { + // Pipfile has no constraint mechanism — add to [packages] + return withAddedDependencies(pins, "packages", null); + } + + @Override + public PipfileFile withDependencySearchMarkers(Map packageMessages, ExecutionContext ctx) { + Toml.Document doc = (Toml.Document) getTree(); + Toml.Document result = (Toml.Document) new TomlIsoVisitor>() { + @Override + public Toml.KeyValue visitKeyValue(Toml.KeyValue keyValue, Map msgs) { + Toml.KeyValue kv = super.visitKeyValue(keyValue, msgs); + if (!isInsideTargetTable(getCursor(), null)) { + return kv; + } + if (!(kv.getKey() instanceof Toml.Identifier)) { + return kv; + } + String pkgName = ((Toml.Identifier) kv.getKey()).getName(); + String message = PythonDependencyFile.getByNormalizedName(msgs, pkgName); + if (message != null) { + return SearchResult.found(kv, message); + } + return kv; + } + }.visitNonNull(doc, packageMessages); + if (result != doc) { + return new PipfileFile(new Cursor(cursor.getParentOrThrow(), result), marker); + } + return this; + } + + // region Helpers + + /** + * Resolve the target table name from the scope parameter. + * {@code null} defaults to "packages". + */ + private static String resolveTableName(@Nullable String scope) { + if (scope == null || scope.isEmpty() || "packages".equals(scope)) { + return "packages"; + } + return scope; + } + + /** + * Check if a table matches the target scope. + * When scope is null, matches both "packages" and "dev-packages". + */ + private static boolean isTargetTable(Toml.Table table, @Nullable String scope) { + if (table.getName() == null) { + return false; + } + String name = table.getName().getName(); + if (scope == null) { + return "packages".equals(name) || "dev-packages".equals(name); + } + return resolveTableName(scope).equals(name); + } + + /** + * Check if the cursor is inside a target Pipfile table. + * When scope is null, matches both "packages" and "dev-packages". + */ + private static boolean isInsideTargetTable(Cursor cursor, @Nullable String scope) { + Cursor c = cursor; + while (c != null) { + Object val = c.getValue(); + if (val instanceof Toml.Table) { + return isTargetTable((Toml.Table) val, scope); + } + c = c.getParent(); + } + return false; + } + + private static boolean hasDependencyInTable(Toml.Document doc, String tableName, String normalizedName) { + for (Toml value : doc.getValues()) { + if (value instanceof Toml.Table) { + Toml.Table table = (Toml.Table) value; + if (table.getName() != null && tableName.equals(table.getName().getName())) { + for (Toml entry : table.getValues()) { + if (entry instanceof Toml.KeyValue) { + Toml.KeyValue kv = (Toml.KeyValue) entry; + if (kv.getKey() instanceof Toml.Identifier && + PythonResolutionResult.normalizeName( + ((Toml.Identifier) kv.getKey()).getName()).equals(normalizedName)) { + return true; + } + } + } + } + } + } + return false; + } + + private static Toml.Document addToTable(Toml.Document doc, String tableName, String packageName, String version) { + String normalizedVersion = PyProjectHelper.normalizeVersionConstraint(version); + return (Toml.Document) new TomlIsoVisitor() { + @Override + public Toml.Table visitTable(Toml.Table table, Integer p) { + Toml.Table t = super.visitTable(table, p); + if (t.getName() == null || !tableName.equals(t.getName().getName())) { + return t; + } + + Toml.Identifier key = new Toml.Identifier( + Tree.randomId(), Space.EMPTY, Markers.EMPTY, packageName, packageName); + Toml.Literal value = new Toml.Literal( + Tree.randomId(), Space.SINGLE_SPACE, Markers.EMPTY, + TomlType.Primitive.String, "\"" + normalizedVersion + "\"", normalizedVersion); + Toml.KeyValue newKv = new Toml.KeyValue( + Tree.randomId(), Space.EMPTY, Markers.EMPTY, + new TomlRightPadded<>(key, Space.SINGLE_SPACE, Markers.EMPTY), + value); + + List values = t.getValues(); + Space entryPrefix = !values.isEmpty() + ? values.get(values.size() - 1).getPrefix() + : Space.format("\n"); + newKv = newKv.withPrefix(entryPrefix); + + List newValues = new ArrayList<>(values); + newValues.add(newKv); + return t.withValues(newValues); + } + }.visitNonNull(doc, 0); + } + + /** + * Update the version in a key-value pair, handling both simple literals + * ({@code requests = ">=2.28.0"}) and inline tables + * ({@code django = {version = ">=3.2", extras = ["postgres"]}}). + */ + private static Toml.KeyValue updateKeyValueVersion(Toml.KeyValue kv, String newVersion) { + String normalizedVersion = PyProjectHelper.normalizeVersionConstraint(newVersion); + if (kv.getValue() instanceof Toml.Literal) { + Toml.Literal literal = (Toml.Literal) kv.getValue(); + if (normalizedVersion.equals(literal.getValue())) { + return kv; + } + return kv.withValue(literal.withSource("\"" + normalizedVersion + "\"").withValue(normalizedVersion)); + } + if (kv.getValue() instanceof Toml.Table) { + // Inline table: update the "version" key inside + Toml.Table inlineTable = (Toml.Table) kv.getValue(); + List newValues = new ArrayList<>(); + boolean changed = false; + for (Toml inner : inlineTable.getValues()) { + if (inner instanceof Toml.KeyValue) { + Toml.KeyValue innerKv = (Toml.KeyValue) inner; + if (innerKv.getKey() instanceof Toml.Identifier && + "version".equals(((Toml.Identifier) innerKv.getKey()).getName()) && + innerKv.getValue() instanceof Toml.Literal) { + Toml.Literal literal = (Toml.Literal) innerKv.getValue(); + if (!normalizedVersion.equals(literal.getValue())) { + newValues.add(innerKv.withValue( + literal.withSource("\"" + normalizedVersion + "\"").withValue(normalizedVersion))); + changed = true; + continue; + } + } + } + newValues.add(inner); + } + if (changed) { + return kv.withValue(inlineTable.withValues(newValues)); + } + } + return kv; + } + + // endregion + + public static class Matcher extends SimpleTraitMatcher { + @Override + protected @Nullable PipfileFile test(Cursor cursor) { + Object value = cursor.getValue(); + if (value instanceof Toml.Document) { + Toml.Document doc = (Toml.Document) value; + if ("Pipfile".equals(doc.getSourcePath().getFileName().toString())) { + PythonResolutionResult marker = doc.getMarkers() + .findFirst(PythonResolutionResult.class).orElse(null); + if (marker != null) { + return new PipfileFile(cursor, marker); + } + } + } + return null; + } + } +} diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java index 782a8399a1..86f0487e12 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java @@ -151,12 +151,20 @@ static PythonResolutionResult updateResolvedVersions( class Matcher extends SimpleTraitMatcher { private final RequirementsFile.Matcher reqMatcher = new RequirementsFile.Matcher(); - private final PyProjectFile.Matcher tomlMatcher = new PyProjectFile.Matcher(); + private final PyProjectFile.Matcher pyprojectMatcher = new PyProjectFile.Matcher(); + private final PipfileFile.Matcher pipfileMatcher = new PipfileFile.Matcher(); @Override protected @Nullable PythonDependencyFile test(Cursor cursor) { PythonDependencyFile r = reqMatcher.test(cursor); - return r != null ? r : tomlMatcher.test(cursor); + if (r != null) { + return r; + } + r = pyprojectMatcher.test(cursor); + if (r != null) { + return r; + } + return pipfileMatcher.test(cursor); } } } diff --git a/rewrite-python/src/test/java/org/openrewrite/python/AddDependencyTest.java b/rewrite-python/src/test/java/org/openrewrite/python/AddDependencyTest.java index 478183a5f9..081f92c248 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/AddDependencyTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/AddDependencyTest.java @@ -360,6 +360,24 @@ void skipWhenAlreadyPresentInRequirementsTxt() { ); } + @Test + void addDependencyToPipfile() { + rewriteRun( + spec -> spec.recipe(new AddDependency("flask", ">=2.0", null, null)), + pipfile( + """ + [packages] + requests = ">=2.28.0" + """, + """ + [packages] + requests = ">=2.28.0" + flask = ">=2.0" + """ + ) + ); + } + @Test void validateRequiresGroupName() { var recipe = new AddDependency("pytest", null, "project.optional-dependencies", null); diff --git a/rewrite-python/src/test/java/org/openrewrite/python/ChangeDependencyTest.java b/rewrite-python/src/test/java/org/openrewrite/python/ChangeDependencyTest.java index b1a141588b..1d000fd899 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/ChangeDependencyTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/ChangeDependencyTest.java @@ -201,4 +201,23 @@ void skipWhenNotFoundInRequirementsTxt() { requirementsTxt("requests>=2.28.0") ); } + + @Test + void renamePackageInPipfile() { + rewriteRun( + spec -> spec.recipe(new ChangeDependency("requests", "httpx", ">=0.24.0")), + pipfile( + """ + [packages] + requests = ">=2.28.0" + click = ">=8.0" + """, + """ + [packages] + httpx = ">=0.24.0" + click = ">=8.0" + """ + ) + ); + } } diff --git a/rewrite-python/src/test/java/org/openrewrite/python/RemoveDependencyTest.java b/rewrite-python/src/test/java/org/openrewrite/python/RemoveDependencyTest.java index 3b96ff028e..25ad1e1534 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/RemoveDependencyTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/RemoveDependencyTest.java @@ -279,4 +279,22 @@ void skipWhenNotPresentInRequirementsTxt() { requirementsTxt("requests>=2.28.0") ); } + + @Test + void removeDependencyFromPipfile() { + rewriteRun( + spec -> spec.recipe(new RemoveDependency("click", null, null)), + pipfile( + """ + [packages] + requests = ">=2.28.0" + click = ">=8.0" + """, + """ + [packages] + requests = ">=2.28.0" + """ + ) + ); + } } diff --git a/rewrite-python/src/test/java/org/openrewrite/python/UpgradeDependencyVersionTest.java b/rewrite-python/src/test/java/org/openrewrite/python/UpgradeDependencyVersionTest.java index 1108478557..641fb2ae5a 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/UpgradeDependencyVersionTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/UpgradeDependencyVersionTest.java @@ -303,4 +303,23 @@ void skipWhenNotPresentInRequirementsTxt() { requirementsTxt("requests>=2.28.0") ); } + + @Test + void changeVersionInPipfile() { + rewriteRun( + spec -> spec.recipe(new UpgradeDependencyVersion("requests", ">=2.31.0", null, null)), + pipfile( + """ + [packages] + requests = ">=2.28.0" + click = ">=8.0" + """, + """ + [packages] + requests = ">=2.31.0" + click = ">=8.0" + """ + ) + ); + } } diff --git a/rewrite-python/src/test/java/org/openrewrite/python/UpgradeTransitiveDependencyVersionTest.java b/rewrite-python/src/test/java/org/openrewrite/python/UpgradeTransitiveDependencyVersionTest.java index b24883eea8..7c13b7c06d 100644 --- a/rewrite-python/src/test/java/org/openrewrite/python/UpgradeTransitiveDependencyVersionTest.java +++ b/rewrite-python/src/test/java/org/openrewrite/python/UpgradeTransitiveDependencyVersionTest.java @@ -21,8 +21,7 @@ import java.nio.file.Path; -import static org.openrewrite.python.Assertions.pyproject; -import static org.openrewrite.python.Assertions.uv; +import static org.openrewrite.python.Assertions.*; class UpgradeTransitiveDependencyVersionTest implements RewriteTest { @@ -494,4 +493,35 @@ void skipFallbackWhenDirectDependency() { ) ); } + + @Test + void addTransitivePinToPipfile() { + rewriteRun( + spec -> spec.recipe(new UpgradeTransitiveDependencyVersion("certifi", ">=2023.7.22")), + pipfile( + """ + [packages] + requests = ">=2.28.0" + """, + """ + [packages] + requests = ">=2.28.0" + certifi = ">=2023.7.22" + """ + ) + ); + } + + @Test + void skipDirectDependencyInPipfile() { + rewriteRun( + spec -> spec.recipe(new UpgradeTransitiveDependencyVersion("requests", ">=2.31.0")), + pipfile( + """ + [packages] + requests = ">=2.28.0" + """ + ) + ); + } } diff --git a/rewrite-python/src/test/java/org/openrewrite/python/trait/PipfileFileTest.java b/rewrite-python/src/test/java/org/openrewrite/python/trait/PipfileFileTest.java new file mode 100644 index 0000000000..95fad0124d --- /dev/null +++ b/rewrite-python/src/test/java/org/openrewrite/python/trait/PipfileFileTest.java @@ -0,0 +1,297 @@ +/* + * Copyright 2026 the original author or authors. + *

+ * Licensed under the Moderne Source Available License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://docs.moderne.io/licensing/moderne-source-available-license + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.python.trait; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.openrewrite.*; +import org.openrewrite.marker.Markers; +import org.openrewrite.marker.SearchResult; +import org.openrewrite.python.marker.PythonResolutionResult; +import org.openrewrite.python.marker.PythonResolutionResult.Dependency; +import org.openrewrite.python.marker.PythonResolutionResult.ResolvedDependency; +import org.openrewrite.test.RewriteTest; +import org.openrewrite.toml.TomlParser; +import org.openrewrite.toml.tree.Toml; + +import java.nio.file.Paths; +import java.util.*; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.openrewrite.Tree.randomId; + +class PipfileFileTest implements RewriteTest { + + private static PythonResolutionResult createMarker(List dependencies) { + return new PythonResolutionResult( + randomId(), null, null, null, null, + "Pipfile", null, null, + Collections.emptyList(), dependencies, + Collections.emptyMap(), Collections.emptyMap(), + Collections.emptyList(), Collections.emptyList(), + Collections.emptyList(), PythonResolutionResult.PackageManager.Pipenv, null + ); + } + + private static Toml.Document parsePipfile(String content, PythonResolutionResult marker) { + TomlParser parser = new TomlParser(); + Parser.Input input = Parser.Input.fromString(Paths.get("Pipfile"), content); + List parsed = parser.parseInputs( + Collections.singletonList(input), null, + new InMemoryExecutionContext(Throwable::printStackTrace) + ).collect(Collectors.toList()); + Toml.Document doc = (Toml.Document) parsed.get(0); + return doc.withMarkers(doc.getMarkers().addIfAbsent(marker)); + } + + private static Cursor rootCursor(Object value) { + return new Cursor(new Cursor(null, Cursor.ROOT_VALUE), value); + } + + private static PipfileFile trait(Toml.Document doc, PythonResolutionResult marker) { + return new PipfileFile(rootCursor(doc), marker); + } + + @Nested + class MatcherTest { + @Test + void matchesPipfile() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + + PythonDependencyFile.Matcher matcher = new PythonDependencyFile.Matcher(); + PythonDependencyFile result = matcher.test(rootCursor(doc)); + + assertThat(result).isNotNull(); + assertThat(result).isInstanceOf(PipfileFile.class); + } + + @Test + void doesNotMatchWithoutMarker() { + TomlParser parser = new TomlParser(); + Parser.Input input = Parser.Input.fromString(Paths.get("Pipfile"), "[packages]\nrequests = \"*\""); + Toml.Document doc = (Toml.Document) parser.parseInputs( + Collections.singletonList(input), null, + new InMemoryExecutionContext(Throwable::printStackTrace) + ).collect(Collectors.toList()).get(0); + + PipfileFile.Matcher matcher = new PipfileFile.Matcher(); + assertThat(matcher.test(rootCursor(doc))).isNull(); + } + + @Test + void doesNotMatchPyprojectToml() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + TomlParser parser = new TomlParser(); + Parser.Input input = Parser.Input.fromString(Paths.get("pyproject.toml"), "[project]\nname = \"test\""); + Toml.Document doc = (Toml.Document) parser.parseInputs( + Collections.singletonList(input), null, + new InMemoryExecutionContext(Throwable::printStackTrace) + ).collect(Collectors.toList()).get(0); + doc = doc.withMarkers(doc.getMarkers().addIfAbsent(marker)); + + PipfileFile.Matcher matcher = new PipfileFile.Matcher(); + assertThat(matcher.test(rootCursor(doc))).isNull(); + } + } + + @Nested + class UpgradeVersionTest { + @Test + void upgradeSimpleVersion() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile upgraded = t.withUpgradedVersions( + Collections.singletonMap("requests", ">=2.31.0"), null, null); + + String printed = ((Toml.Document) upgraded.getTree()).printAll(); + assertThat(printed).contains("requests = \">=2.31.0\""); + } + + @Test + void upgradeInDevPackages() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile( + "[packages]\nflask = \"*\"\n\n[dev-packages]\npytest = \">=7.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile upgraded = t.withUpgradedVersions( + Collections.singletonMap("pytest", ">=8.0"), "dev-packages", null); + + String printed = ((Toml.Document) upgraded.getTree()).printAll(); + assertThat(printed).contains("pytest = \">=8.0\""); + assertThat(printed).contains("flask = \"*\""); + } + + @Test + void noOpWhenNotFound() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile upgraded = t.withUpgradedVersions( + Collections.singletonMap("nonexistent", ">=1.0"), null, null); + + assertThat(upgraded).isSameAs(t); + } + } + + @Nested + class AddDependencyTest { + @Test + void addToPackages() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile added = t.withAddedDependencies( + Collections.singletonMap("flask", ">=2.0"), "packages", null); + + String printed = ((Toml.Document) added.getTree()).printAll(); + assertThat(printed).contains("flask = \">=2.0\""); + assertThat(printed).contains("requests = \">=2.28.0\""); + } + + @Test + void addToDevPackages() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile( + "[packages]\nrequests = \"*\"\n\n[dev-packages]\npytest = \">=7.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile added = t.withAddedDependencies( + Collections.singletonMap("mypy", ">=1.0"), "dev-packages", null); + + String printed = ((Toml.Document) added.getTree()).printAll(); + assertThat(printed).contains("mypy = \">=1.0\""); + } + + @Test + void noOpWhenAlreadyPresent() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile added = t.withAddedDependencies( + Collections.singletonMap("requests", ">=2.31.0"), "packages", null); + + assertThat(added).isSameAs(t); + } + } + + @Nested + class RemoveDependencyTest { + @Test + void removeFromPackages() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"\nflask = \"*\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile removed = t.withRemovedDependencies( + Collections.singleton("flask"), "packages", null); + + String printed = ((Toml.Document) removed.getTree()).printAll(); + assertThat(printed).contains("requests = \">=2.28.0\""); + assertThat(printed).doesNotContain("flask"); + } + + @Test + void noOpWhenNotFound() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile removed = t.withRemovedDependencies( + Collections.singleton("nonexistent"), "packages", null); + + assertThat(removed).isSameAs(t); + } + } + + @Nested + class ChangeDependencyTest { + @Test + void renamePackage() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile changed = t.withChangedDependency("requests", "httpx", null); + + String printed = ((Toml.Document) changed.getTree()).printAll(); + assertThat(printed).contains("httpx = \">=2.28.0\""); + assertThat(printed).doesNotContain("requests"); + } + + @Test + void renameWithNewVersion() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + PipfileFile changed = t.withChangedDependency("requests", "httpx", ">=0.24.0"); + + String printed = ((Toml.Document) changed.getTree()).printAll(); + assertThat(printed).contains("httpx = \">=0.24.0\""); + } + } + + @Nested + class SearchMarkersTest { + @Test + void markVulnerableDependency() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile( + "[packages]\nrequests = \">=2.28.0\"\nflask = \"*\"", marker); + PipfileFile t = trait(doc, marker); + + ExecutionContext ctx = new InMemoryExecutionContext(Throwable::printStackTrace); + PipfileFile marked = t.withDependencySearchMarkers( + Collections.singletonMap("requests", "CVE-2023-1234"), ctx); + + Toml.Document result = (Toml.Document) marked.getTree(); + boolean[] found = {false}; + new org.openrewrite.toml.TomlVisitor() { + @Override + public Toml visitKeyValue(Toml.KeyValue keyValue, Integer p) { + if (keyValue.getKey() instanceof Toml.Identifier && + "requests".equals(((Toml.Identifier) keyValue.getKey()).getName()) && + keyValue.getMarkers().findFirst(SearchResult.class).isPresent()) { + found[0] = true; + } + return keyValue; + } + }.visit(result, 0); + assertThat(found[0]).as("requests should have SearchResult marker").isTrue(); + } + + @Test + void noOpWhenNoMatch() { + PythonResolutionResult marker = createMarker(Collections.emptyList()); + Toml.Document doc = parsePipfile("[packages]\nrequests = \">=2.28.0\"", marker); + PipfileFile t = trait(doc, marker); + + ExecutionContext ctx = new InMemoryExecutionContext(Throwable::printStackTrace); + PipfileFile marked = t.withDependencySearchMarkers( + Collections.singletonMap("nonexistent", "CVE-2023-9999"), ctx); + + assertThat(marked).isSameAs(t); + } + } +} From 2528b41df62659dc22b57e92a26ca6b2f1b56e40 Mon Sep 17 00:00:00 2001 From: Jente Sondervorst Date: Fri, 3 Apr 2026 21:32:21 +0200 Subject: [PATCH 09/10] further optimizations --- .../java/org/openrewrite/python/AddDependency.java | 8 ++------ .../org/openrewrite/python/ChangeDependency.java | 8 ++------ .../org/openrewrite/python/RemoveDependency.java | 8 ++------ .../openrewrite/python/UpgradeDependencyVersion.java | 8 ++------ .../python/UpgradeTransitiveDependencyVersion.java | 8 ++------ .../org/openrewrite/python/trait/PipfileFile.java | 12 ++++++++++-- .../org/openrewrite/python/trait/PyProjectFile.java | 7 +++++++ .../python/trait/PythonDependencyFile.java | 12 ++++++++++++ 8 files changed, 39 insertions(+), 32 deletions(-) diff --git a/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java index fd4f10bb1b..a2e73f537e 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java @@ -150,12 +150,8 @@ public TreeVisitor getVisitor(Accumulator acc) { String ver = version != null ? version : ""; Map additions = Collections.singletonMap(packageName, ver); PythonDependencyFile updated = trait.withAddedDependencies(additions, scope, groupName); - SourceFile result = (SourceFile) updated.getTree(); - if (result != tree) { - if (result instanceof Toml.Document) { - return PyProjectHelper.regenerateLockAndRefreshMarker((Toml.Document) result, ctx); - } - return result; + if (updated.getTree() != tree) { + return updated.afterModification(ctx); } } } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java index 03e70e2829..3259e0f9fd 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java @@ -123,12 +123,8 @@ public TreeVisitor getVisitor(Accumulator acc) { PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null); if (trait != null) { PythonDependencyFile updated = trait.withChangedDependency(oldPackageName, newPackageName, newVersion); - SourceFile result = (SourceFile) updated.getTree(); - if (result != tree) { - if (result instanceof Toml.Document) { - return PyProjectHelper.regenerateLockAndRefreshMarker((Toml.Document) result, ctx); - } - return result; + if (updated.getTree() != tree) { + return updated.afterModification(ctx); } } } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java index cb97c1f633..997f37e6ef 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java @@ -139,12 +139,8 @@ public TreeVisitor getVisitor(Accumulator acc) { if (trait != null) { PythonDependencyFile updated = trait.withRemovedDependencies( Collections.singleton(packageName), scope, groupName); - SourceFile result = (SourceFile) updated.getTree(); - if (result != tree) { - if (result instanceof Toml.Document) { - return PyProjectHelper.regenerateLockAndRefreshMarker((Toml.Document) result, ctx); - } - return result; + if (updated.getTree() != tree) { + return updated.afterModification(ctx); } } } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java index 5bf51660ec..4492e8e2a3 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java @@ -152,12 +152,8 @@ public TreeVisitor getVisitor(Accumulator acc) { Map upgrades = Collections.singletonMap( PythonResolutionResult.normalizeName(packageName), newVersion); PythonDependencyFile updated = trait.withUpgradedVersions(upgrades, scope, groupName); - SourceFile result = (SourceFile) updated.getTree(); - if (result != tree) { - if (result instanceof Toml.Document) { - return PyProjectHelper.regenerateLockAndRefreshMarker((Toml.Document) result, ctx); - } - return result; + if (updated.getTree() != tree) { + return updated.afterModification(ctx); } } } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java index 6d617febef..f8adfd59b4 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java @@ -135,12 +135,8 @@ public TreeVisitor getVisitor(Accumulator acc) { String normalizedName = PythonResolutionResult.normalizeName(packageName); Map pins = Collections.singletonMap(normalizedName, version); PythonDependencyFile updated = trait.withPinnedTransitiveDependencies(pins); - SourceFile result = updated.getTree(); - if (result != tree) { - if (result instanceof Toml.Document) { - return PyProjectHelper.regenerateLockAndRefreshMarker((Toml.Document) result, ctx); - } - return result; + if (updated.getTree() != tree) { + return updated.afterModification(ctx); } } } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PipfileFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PipfileFile.java index 69cdc0d79a..8255c5507d 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/PipfileFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PipfileFile.java @@ -56,7 +56,11 @@ public Toml.KeyValue visitKeyValue(Toml.KeyValue keyValue, Map u } }.visitNonNull(doc, upgrades); if (result != doc) { - return new PipfileFile(new Cursor(cursor.getParentOrThrow(), result), marker); + PythonResolutionResult updatedMarker = PythonDependencyFile.updateResolvedVersions(marker, upgrades); + result = result.withMarkers(result.getMarkers() + .removeByType(PythonResolutionResult.class) + .addIfAbsent(updatedMarker)); + return new PipfileFile(new Cursor(cursor.getParentOrThrow(), result), updatedMarker); } return this; } @@ -73,7 +77,11 @@ public PipfileFile withAddedDependencies(Map additions, @Nullabl } } if (doc != original) { - return new PipfileFile(new Cursor(cursor.getParentOrThrow(), doc), marker); + PythonResolutionResult updatedMarker = PythonDependencyFile.updateResolvedVersions(marker, additions); + doc = doc.withMarkers(doc.getMarkers() + .removeByType(PythonResolutionResult.class) + .addIfAbsent(updatedMarker)); + return new PipfileFile(new Cursor(cursor.getParentOrThrow(), doc), updatedMarker); } return this; } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java index 89bf10ef71..32323c2e65 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java @@ -9,6 +9,7 @@ import org.jspecify.annotations.Nullable; import org.openrewrite.Cursor; import org.openrewrite.ExecutionContext; +import org.openrewrite.SourceFile; import org.openrewrite.Tree; import org.openrewrite.marker.Markers; import org.openrewrite.marker.SearchResult; @@ -407,6 +408,12 @@ public Toml.Literal visitLiteral(Toml.Literal literal, Map msgs) return this; } + @Override + public SourceFile afterModification(ExecutionContext ctx) { + Toml.Document doc = (Toml.Document) getTree(); + return PyProjectHelper.regenerateLockAndRefreshMarker(doc, ctx); + } + private static boolean isInsideTargetArray(Cursor cursor, @Nullable String scope, @Nullable String groupName) { Cursor c = cursor; while (c != null) { diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java index 86f0487e12..4365661e2e 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PythonDependencyFile.java @@ -82,6 +82,18 @@ public interface PythonDependencyFile extends Trait { */ PythonDependencyFile withDependencySearchMarkers(Map packageMessages, ExecutionContext ctx); + /** + * Post-process the modified source file, e.g. regenerate lock files. + * Called by recipes after a trait method modifies the tree. + * The default implementation returns the tree unchanged. + * + * @param ctx the execution context + * @return the post-processed source file + */ + default SourceFile afterModification(ExecutionContext ctx) { + return getTree(); + } + /** * Rewrite a PEP 508 dependency spec with a new version constraint. * Preserves extras and environment markers. The version is normalized From 8ee97cf3ed3d1fa0f3c8d75db500baedcc4b1922 Mon Sep 17 00:00:00 2001 From: Jente Sondervorst Date: Fri, 3 Apr 2026 22:28:14 +0200 Subject: [PATCH 10/10] further optimizations --- .../main/java/org/openrewrite/python/AddDependency.java | 6 +++--- .../java/org/openrewrite/python/ChangeDependency.java | 8 ++++---- .../java/org/openrewrite/python/RemoveDependency.java | 6 +++--- .../org/openrewrite/python/UpgradeDependencyVersion.java | 6 +++--- .../python/UpgradeTransitiveDependencyVersion.java | 7 ++++--- .../java/org/openrewrite/python/trait/PyProjectFile.java | 4 ++-- 6 files changed, 19 insertions(+), 18 deletions(-) diff --git a/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java index a2e73f537e..1cf3472b23 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java @@ -30,8 +30,8 @@ import java.util.Set; /** - * Add a dependency to a Python project. Supports both {@code pyproject.toml} - * (with scope and group targeting) and {@code requirements.txt} files. + * Add a dependency to a Python project. Supports {@code pyproject.toml} + * (with scope and group targeting), {@code requirements.txt}, and {@code Pipfile}. * When uv is available, the uv.lock file is regenerated to reflect the change. */ @EqualsAndHashCode(callSuper = false) @@ -90,7 +90,7 @@ public String getInstanceNameSuffix() { @Override public String getDescription() { return "Add a dependency to a Python project. Supports `pyproject.toml` " + - "(with scope/group targeting) and `requirements.txt` files. " + + "(with scope/group targeting), `requirements.txt`, and `Pipfile`. " + "When `uv` is available, the `uv.lock` file is regenerated."; } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java index 3259e0f9fd..a9f8f9952e 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/ChangeDependency.java @@ -28,8 +28,8 @@ import java.util.Set; /** - * Change a dependency to a different package. Supports both {@code pyproject.toml} - * and {@code requirements.txt} files. Searches all dependency scopes. + * Change a dependency to a different package. Supports {@code pyproject.toml}, + * {@code requirements.txt}, and {@code Pipfile}. Searches all dependency scopes. * When uv is available, the uv.lock file is regenerated to reflect the change. */ @EqualsAndHashCode(callSuper = false) @@ -65,8 +65,8 @@ public String getInstanceNameSuffix() { @Override public String getDescription() { - return "Change a dependency to a different package. Supports `pyproject.toml` " + - "and `requirements.txt` files. Searches all dependency scopes. " + + return "Change a dependency to a different package. Supports `pyproject.toml`, " + + "`requirements.txt`, and `Pipfile`. Searches all dependency scopes. " + "When `uv` is available, the `uv.lock` file is regenerated."; } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java b/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java index 997f37e6ef..8fd8e46f1b 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/RemoveDependency.java @@ -29,8 +29,8 @@ import java.util.Set; /** - * Remove a dependency from a Python project. Supports both {@code pyproject.toml} - * (with scope and group targeting) and {@code requirements.txt} files. + * Remove a dependency from a Python project. Supports {@code pyproject.toml} + * (with scope and group targeting), {@code requirements.txt}, and {@code Pipfile}. * When uv is available, the uv.lock file is regenerated to reflect the change. */ @EqualsAndHashCode(callSuper = false) @@ -80,7 +80,7 @@ public String getInstanceNameSuffix() { @Override public String getDescription() { return "Remove a dependency from a Python project. Supports `pyproject.toml` " + - "(with scope/group targeting) and `requirements.txt` files. " + + "(with scope/group targeting), `requirements.txt`, and `Pipfile`. " + "When `uv` is available, the `uv.lock` file is regenerated."; } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java index 4492e8e2a3..7bec9b8c1e 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeDependencyVersion.java @@ -31,8 +31,8 @@ import java.util.Set; /** - * Upgrade the version constraint for a dependency. Supports both {@code pyproject.toml} - * (with scope and group targeting) and {@code requirements.txt} files. + * Upgrade the version constraint for a dependency. Supports {@code pyproject.toml} + * (with scope and group targeting), {@code requirements.txt}, and {@code Pipfile}. * When uv is available, the uv.lock file is regenerated to reflect the change. */ @EqualsAndHashCode(callSuper = false) @@ -87,7 +87,7 @@ public String getInstanceNameSuffix() { @Override public String getDescription() { return "Upgrade the version constraint for a dependency. Supports `pyproject.toml` " + - "(with scope/group targeting) and `requirements.txt` files. " + + "(with scope/group targeting), `requirements.txt`, and `Pipfile`. " + "When `uv` is available, the `uv.lock` file is regenerated."; } diff --git a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java index f8adfd59b4..3448ff467f 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/UpgradeTransitiveDependencyVersion.java @@ -34,8 +34,9 @@ * Pin a transitive dependency version using the strategy appropriate for the file type * and package manager. For {@code pyproject.toml}: uv uses * {@code [tool.uv].constraint-dependencies}, PDM uses {@code [tool.pdm.overrides]}, - * and other managers add a direct dependency. For {@code requirements.txt}: appends - * the dependency. When uv is available, the uv.lock file is regenerated. + * and other managers add a direct dependency. For {@code requirements.txt} and + * {@code Pipfile}: appends the dependency. When uv is available, the uv.lock file + * is regenerated. */ @EqualsAndHashCode(callSuper = false) @Value @@ -66,7 +67,7 @@ public String getDescription() { return "Pin a transitive dependency version using the strategy appropriate for the file type " + "and package manager. For `pyproject.toml`: uv uses `[tool.uv].constraint-dependencies`, " + "PDM uses `[tool.pdm.overrides]`, and other managers add a direct dependency. " + - "For `requirements.txt`: appends the dependency."; + "For `requirements.txt` and `Pipfile`: appends the dependency."; } static class Accumulator { diff --git a/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java index 32323c2e65..ec4ec21ad0 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/trait/PyProjectFile.java @@ -410,8 +410,8 @@ public Toml.Literal visitLiteral(Toml.Literal literal, Map msgs) @Override public SourceFile afterModification(ExecutionContext ctx) { - Toml.Document doc = (Toml.Document) getTree(); - return PyProjectHelper.regenerateLockAndRefreshMarker(doc, ctx); + // regenerateLockAndRefreshMarker already guards against missing lock content internally + return PyProjectHelper.regenerateLockAndRefreshMarker((Toml.Document) getTree(), ctx); } private static boolean isInsideTargetArray(Cursor cursor, @Nullable String scope, @Nullable String groupName) {