diff --git a/passes/pmgen/xilinx_dsp.pmg b/passes/pmgen/xilinx_dsp.pmg
index 7be841ff3..3aab807bd 100644
--- a/passes/pmgen/xilinx_dsp.pmg
+++ b/passes/pmgen/xilinx_dsp.pmg
@@ -1,5 +1,6 @@
 pattern xilinx_dsp
 
+state <std::function<SigSpec(const SigSpec&, bool)>> unextend
 state <SigBit> clock
 state <SigSpec> sigA sigffAmuxY sigB sigffBmuxY sigC sigD sigffDmuxY sigM sigP
 state <IdString> postAddAB postAddMuxAB
@@ -10,29 +11,26 @@ match dsp
 	select dsp->type.in(\DSP48E1)
 endmatch
 
-code sigA sigffAmuxY sigB sigffBmuxY sigD sigM
-	sigA = port(dsp, \A);
-	int i;
-	for (i = GetSize(sigA)-1; i > 0; i--)
-		if (sigA[i] != sigA[i-1])
-			break;
-	// Do not remove non-const sign bit
-	if (sigA[i].wire)
-		++i;
-	sigA.remove(i, GetSize(sigA)-i);
-	sigB = port(dsp, \B);
-	for (i = GetSize(sigB)-1; i > 0; i--)
-		if (sigB[i] != sigB[i-1])
-			break;
-	// Do not remove non-const sign bit
-	if (sigB[i].wire)
-		++i;
-	sigB.remove(i, GetSize(sigB)-i);
+code unextend sigA sigffAmuxY sigB sigffBmuxY sigC sigD sigffDmuxY sigM
+	unextend = [](const SigSpec &sig, bool keep_sign) {
+		int i;
+		for (i = GetSize(sig)-1; i > 0; i--)
+			if (sig[i] != sig[i-1])
+				break;
+		// Do not remove non-const sign bit
+		if (!keep_sign && sig[i].wire)
+			++i;
+		return sig.extract(0, i);
+	};
+	sigA = unextend(port(dsp, \A), false);
+	sigB = unextend(port(dsp, \B), false);
 
+	sigC = dsp->connections_.at(\C, SigSpec());
 	sigD = dsp->connections_.at(\D, SigSpec());
 
 	SigSpec P = port(dsp, \P);
 	// Only care about those bits that are used
+	int i;
 	for (i = 0; i < GetSize(P); i++) {
 		if (nusers(P[i]) <= 1)
 			break;
@@ -44,6 +42,7 @@ code sigA sigffAmuxY sigB sigffBmuxY sigD sigM
 
 	sigffAmuxY = SigSpec();
 	sigffBmuxY = SigSpec();
+	sigffDmuxY = SigSpec();
 endcode
 
 match ffAD