Thursday, November 25, 2010

TopCoder SRM 477 Div 1 Medium (PythTriplets)

Problem Link : http://www.topcoder.com/stat?c=problem_statement&pm=10766&rd=14157

In this problem, you are given some integers (at most 200) in the range [1, 999999]. You have to find the maximum number of disjoint pairs (a, b), i.e. pairs in which each given number is used at most once, such that a and b are co-prime and a^2 + b^2 = c^2 for some integer c.

Now this is a classic bipartite matching problem, although it may not be evident how to split the numbers into two sets. Let's consider the possible cases according to the parity of the two numbers in a pair -

  1. Both of them are even: This cannot be a valid pair, because both numbers are divisible by 2, so they are not co-prime.
  2. One is even, the other is odd: This is valid. Examples: 3^2 + 4^2 = 5^2, 12^2 + 5^2 = 13^2.
  3. Both of them are odd: This also is not a valid pair. Proof below -
Let a = 2n + 1 and b = 2m + 1
So, x = a^2 + b^2 = (2n+1)^2 + (2m+1)^2 = 4(n^2+m^2) + 4(n+m) + 2 = 2 [ 2(n^2+m^2) + 2(n+m) + 1 ]
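The first two terms inside the square brackets are even, so the whole bracketed expression is odd. Hence x = 2y where y is odd, which means x is even but not divisible by 4. If x were a perfect square c^2, then c would have to be even (an odd c gives an odd square), say c = 2k, and then c^2 = 4k^2 would be divisible by 4. So x cannot be a perfect square.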

So we can split the numbers into two sets according to their parity: every valid pair consists of one odd and one even number, so every edge goes between the two sets and the graph is bipartite. After that, all that remains is to implement a bipartite matching algorithm. The code I've written here is inspired by Igor Naverniouk aka Abednego.

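Note that instead of splitting the numbers according to their parity, we could have matched a copy of the whole set against the original set and then divided the size of the matching by two to get the actual answer. This works because the graph is bipartite: the doubled graph then consists of two copies of the original graph, so its maximum matching is exactly twice as large.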

Source code below - 

/*
Problem Name : PythTriplets
Problem Number : TopCoder SRM 477 Div 1 Medium
Link : http://www.topcoder.com/stat?c=problem_statement&pm=10766&rd=14157
Problem Type : Bipartite Matching
Difficulty : 5.5 / 10
Interest : 7 / 10
Complexity : O(N^3)
Date : November 25, 2010
*/
#include<string>
#include<vector>
#include<sstream>
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
using namespace std;

#define rep(i,n) for(i=0;i<(n);i++)
#define MAXN 200

bool adj[MAXN+2][MAXN+2];
bool flag[MAXN+2];
int matchL[MAXN+2];
int matchR[MAXN+2];
vector<int> odd, even;
int n,m;

struct PythTriplets{
 // Split s into tokens separated by any character in delim, skipping empty tokens.
 vector<string> parse(const string& s,const string& delim=" ") {
  vector<string>res;
  string t;
  for(int i=0;i!=s.size();i++) {
   if(delim.find(s[i]) != string::npos) {
    if(!t.empty()) {
     res.push_back(t);
     t="";
    }
   }
   else t+=s[i];
  }
  if(!t.empty()) res.push_back(t);
  return res;
 }

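 // Augmenting-path step (Kuhn's algorithm): try to match left vertex u (an odd
 // number) to a right vertex (an even number), re-matching earlier left vertices
 // if needed. flag[i] marks right vertices already visited in the current search.
 bool bpm(int u) {
  int i;
  rep(i,m) if(adj[u][i]) {
   if(flag[i]) continue;
   flag[i] = true;
   // take i if it is free, or if its current partner can be moved elsewhere
   if(matchR[i] < 0 || bpm(matchR[i]) ) {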
    matchL[u] = i;
    matchR[i] = u;
    return true;
   }
  }
  return false;
 }

 int gcd(int a, int b) {
  if(b == 0) return a;
  return gcd(b, a%b);
 }

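 // Valid pair: a and b are co-prime and a^2 + b^2 is a perfect square.
 // a, b <= 999999, so c = a^2 + b^2 < 2 * 10^12 fits comfortably in a long long.
 bool ok(long long a, long long b) {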
  if(gcd(a,b) != 1) return false;
  long long c = a * a + b * b;
  long long s = sqrt((long double)c);
  if(s * s == c) return true;
  return false;
 }

 int findMax(vector <string> stick) {
  int i,j,x;
  string s;
  //parse input
  rep(i,stick.size()) s += stick[i];
  vector<string> vs = parse(s);
  odd.clear(); even.clear();
  rep(i,vs.size()) {
   x = atoi(vs[i].c_str());
   if(x&1) odd.push_back(x);
   else even.push_back(x);
  }

  //create adjacency matrix
  n = odd.size();
  m = even.size();
  memset(adj,0,sizeof(adj));
  rep(i,n) rep(j,m) if(ok(odd[i], even[j])) adj[i][j] = 1;

  //bpm
  memset(matchL,-1,sizeof(matchL));
  memset(matchR,-1,sizeof(matchR));
  int res = 0;
  rep(i,n) {
   memset(flag,0,sizeof(flag));
   if(bpm(i)) res++;
  }
  return res;
 }
};
