//===--- TransGCAttrs.cpp - Transformations to ARC mode --------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#include "Transforms.h"
#include "Internals.h"
#include "clang/AST/ASTContext.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Lex/Lexer.h"
#include "clang/Sema/SemaDiagnostic.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TinyPtrVector.h"
#include "llvm/Support/SaveAndRestore.h"

using namespace clang;
using namespace arcmt;
using namespace trans;

namespace {

/// \brief Collects all the places where GC attributes __strong/__weak occur.
///
/// Each new occurrence is appended to MigrationContext::GCAttrs and its raw
/// attribute-name location is added to MigrationContext::AttrSet, so the same
/// location is recorded only once. Every ObjCPropertyDecl reached by the
/// traversal is also appended to the vector handed to the constructor.
class GCAttrsCollector : public RecursiveASTVisitor<GCAttrsCollector> {
  MigrationContext &MigrateCtx;
  /// isMigratable() result for the declaration currently being traversed.
  bool FullyMigratable;
  std::vector<ObjCPropertyDecl *> &AllProps;

  typedef RecursiveASTVisitor<GCAttrsCollector> base;
public:
  GCAttrsCollector(MigrationContext &ctx,
                   std::vector<ObjCPropertyDecl *> &AllProps)
    : MigrateCtx(ctx),
      FullyMigratable(false),
      AllProps(AllProps) { }

  bool shouldWalkTypesOfTypeLocs() const { return false; }

  /// Visitor hook for attributed type locations. An occurrence first recorded
  /// through this path has a null Dcl.
  bool VisitAttributedTypeLoc(AttributedTypeLoc TL) {
    handleAttr(TL);
    return true;
  }

  /// Traversal hook: skips null and implicit declarations, inspects the type
  /// of properties and declarators for GC attributes, then continues with the
  /// default traversal.
  bool TraverseDecl(Decl *D) {
    if (!D)
      return true;
    if (D->isImplicit())
      return true;

    // FullyMigratable describes D while D and its children are traversed.
    SaveAndRestore<bool> Save(FullyMigratable, isMigratable(D));

    ObjCPropertyDecl *Prop = dyn_cast<ObjCPropertyDecl>(D);
    if (Prop) {
      lookForAttribute(Prop, Prop->getTypeSourceInfo());
      AllProps.push_back(Prop);
    } else {
      DeclaratorDecl *DDecl = dyn_cast<DeclaratorDecl>(D);
      if (DDecl)
        lookForAttribute(DDecl, DDecl->getTypeSourceInfo());
    }

    return base::TraverseDecl(D);
  }

  /// Peels qualifiers, attributes, arrays, pointers and references off the
  /// written type of \p D, handing every attributed layer to handleAttr().
  /// Stops at the first layer handleAttr() accepts or at any other kind of
  /// type location.
  void lookForAttribute(Decl *D, TypeSourceInfo *TInfo) {
    if (!TInfo)
      return;

    for (TypeLoc Cur = TInfo->getTypeLoc(); Cur; ) {
      if (const QualifiedTypeLoc *Qual = dyn_cast<QualifiedTypeLoc>(&Cur)) {
        Cur = Qual->getUnqualifiedLoc();
        continue;
      }
      if (const AttributedTypeLoc *
            Attributed = dyn_cast<AttributedTypeLoc>(&Cur)) {
        if (handleAttr(*Attributed, D))
          return;
        Cur = Attributed->getModifiedLoc();
        continue;
      }
      if (const ArrayTypeLoc *Array = dyn_cast<ArrayTypeLoc>(&Cur)) {
        Cur = Array->getElementLoc();
        continue;
      }
      if (const PointerTypeLoc *Pointer = dyn_cast<PointerTypeLoc>(&Cur)) {
        Cur = Pointer->getPointeeLoc();
        continue;
      }
      if (const ReferenceTypeLoc *Ref = dyn_cast<ReferenceTypeLoc>(&Cur)) {
        Cur = Ref->getPointeeLoc();
        continue;
      }
      // Any other kind of type location ends the search.
      return;
    }
  }

  /// Records \p TL as a GC attribute occurrence if it is an objc_ownership
  /// attribute whose operand is spelled "strong" or "weak".
  ///
  /// \returns true if the attribute was already recorded or has just been
  /// recorded; false if it is another attribute kind, its spelling could not
  /// be retrieved, or the spelling is neither "strong" nor "weak".
  bool handleAttr(AttributedTypeLoc TL, Decl *D = 0) {
    if (TL.getAttrKind() != AttributedType::attr_objc_ownership)
      return false;

    SourceLocation Loc = TL.getAttrNameLoc();
    // De-duplication key; taken before the macro adjustment of Loc below.
    unsigned RawLoc = Loc.getRawEncoding();
    if (MigrateCtx.AttrSet.count(RawLoc))
      return true;

    ASTContext &Ctx = MigrateCtx.Pass.Ctx;
    SourceManager &SM = Ctx.getSourceManager();
    if (Loc.isMacroID())
      Loc = SM.getImmediateExpansionRange(Loc).first;

    SmallString<32> Buf;
    bool Invalid = false;
    SourceLocation OperandLoc = SM.getSpellingLoc(TL.getAttrEnumOperandLoc());
    StringRef Spell = Lexer::getSpelling(OperandLoc, Buf, SM,
                                         Ctx.getLangOpts(), &Invalid);
    if (Invalid)
      return false;

    typedef MigrationContext::GCAttrOccurrence Occurrence;
    Occurrence::AttrKind Kind;
    if (Spell == "weak")
      Kind = Occurrence::Weak;
    else if (Spell == "strong")
      Kind = Occurrence::Strong;
    else
      return false;

    MigrateCtx.AttrSet.insert(RawLoc);
    MigrateCtx.GCAttrs.push_back(Occurrence());

    Occurrence &Entry = MigrateCtx.GCAttrs.back();
    Entry.Kind = Kind;
    Entry.Loc = Loc;
    Entry.ModifiedType = TL.getModifiedLoc().getType();
    Entry.Dcl = D;
    Entry.FullyMigratable = FullyMigratable;
    return true;
  }

  /// Decides whether \p D counts as migratable:
  ///  - the translation unit itself never does;
  ///  - a declaration whose redeclarations are all in the main file does;
  ///  - a function does if it has a body;
  ///  - an ObjC container does if hasObjCImpl() says so;
  ///  - a C++ record does if at least one of its methods is out-of-line;
  ///  - anything else defers to its enclosing declaration context.
  bool isMigratable(Decl *D) {
    if (isa<TranslationUnitDecl>(D))
      return false;
    if (isInMainFile(D))
      return true;

    if (FunctionDecl *Fn = dyn_cast<FunctionDecl>(D))
      return Fn->hasBody();
    if (ObjCContainerDecl *ObjCCont = dyn_cast<ObjCContainerDecl>(D))
      return hasObjCImpl(ObjCCont);

    CXXRecordDecl *CXXRec = dyn_cast<CXXRecordDecl>(D);
    if (!CXXRec)
      return isMigratable(cast<Decl>(D->getDeclContext()));

    CXXRecordDecl::method_iterator MethI = CXXRec->method_begin();
    CXXRecordDecl::method_iterator MethE = CXXRec->method_end();
    while (MethI != MethE) {
      if (MethI->isOutOfLine())
        return true;
      ++MethI;
    }
    return false;
  }

  /// \returns true if \p D is an ObjC implementation, or an interface or
  /// category for which getImplementation() is non-null; false for null and
  /// for every other declaration.
  static bool hasObjCImpl(Decl *D) {
    if (!D)
      return false;

    ObjCContainerDecl *ObjCCont = dyn_cast<ObjCContainerDecl>(D);
    if (!ObjCCont)
      return false;

    if (ObjCInterfaceDecl *Iface = dyn_cast<ObjCInterfaceDecl>(ObjCCont))
      return Iface->getImplementation() != 0;
    if (ObjCCategoryDecl *Cat = dyn_cast<ObjCCategoryDecl>(ObjCCont))
      return Cat->getImplementation() != 0;
    return isa<ObjCImplDecl>(ObjCCont);
  }

  /// \returns true only if \p D is non-null and the location of every one of
  /// its redeclarations is in the main file.
  bool isInMainFile(Decl *D) {
    if (!D)
      return false;

    Decl::redecl_iterator RedeclI = D->redecls_begin();
    Decl::redecl_iterator RedeclE = D->redecls_end();
    for (; RedeclI != RedeclE; ++RedeclI) {
      if (!isInMainFile(RedeclI->getLocation()))
        return false;
    }
    return true;
  }

  /// \returns true if \p Loc is valid and its expansion location lies in the
  /// main file of the source manager.
  bool isInMainFile(SourceLocation Loc) {
    if (Loc.isInvalid())
      return false;

    SourceManager &SM = MigrateCtx.Pass.Ctx.getSourceManager();
    SourceLocation ExpansionLoc = SM.getExpansionLoc(Loc);
    return SM.isInFileID(ExpansionLoc, SM.getMainFileID());
  }
};

} // anonymous namespace

/// Reports an error for every recorded GC attribute that is fully migratable,
/// has an associated declaration, and modifies a non-null type that is not an
/// ObjC retainable type.
static void errorForGCAttrsOnNonObjC(MigrationContext &MigrateCtx) {
  TransformActions &TA = MigrateCtx.Pass.TA;

  unsigned NumAttrs = MigrateCtx.GCAttrs.size();
  for (unsigned Idx = 0; Idx != NumAttrs; ++Idx) {
    MigrationContext::GCAttrOccurrence &Occ = MigrateCtx.GCAttrs[Idx];

    // Only occurrences tied to a fully migratable declaration are checked.
    if (!Occ.FullyMigratable || !Occ.Dcl)
      continue;
    if (Occ.ModifiedType.isNull())
      continue;
    if (Occ.ModifiedType->isObjCRetainableType())
      continue;

    TA.reportError("GC managed memory will become unmanaged in ARC",
                   Occ.Loc);
  }
}

/// For every recorded weak GC attribute on an ObjC retainable type to which
/// canApplyWeak() says weak cannot be applied: rewrites "__weak" to
/// "__unsafe_unretained" (unless the attribute location is already in
/// RemovedAttrSet) and clears the related ARC weak diagnostics at that
/// location.
static void checkWeakGCAttrs(MigrationContext &MigrateCtx) {
  TransformActions &TA = MigrateCtx.Pass.TA;

  unsigned NumAttrs = MigrateCtx.GCAttrs.size();
  for (unsigned Idx = 0; Idx != NumAttrs; ++Idx) {
    MigrationContext::GCAttrOccurrence &Occ = MigrateCtx.GCAttrs[Idx];

    if (Occ.Kind != MigrationContext::GCAttrOccurrence::Weak)
      continue;
    if (Occ.ModifiedType.isNull())
      continue;
    if (!Occ.ModifiedType->isObjCRetainableType())
      continue;
    if (canApplyWeak(MigrateCtx.Pass.Ctx, Occ.ModifiedType,
                     /*AllowOnUnknownClass=*/true))
      continue;

    // Both edits below belong to one transaction.
    Transaction Trans(TA);
    bool AlreadyRemoved =
      MigrateCtx.RemovedAttrSet.count(Occ.Loc.getRawEncoding());
    if (!AlreadyRemoved)
      TA.replaceText(Occ.Loc, "__weak", "__unsafe_unretained");
    TA.clearDiagnostic(diag::err_arc_weak_no_runtime,
                       diag::err_arc_unsupported_weak_class,
                       Occ.Loc);
  }
}

typedef llvm::TinyPtrVector<ObjCPropertyDecl *> IndivPropsTy;

/// Handles the group of properties \p IndProps that share the '@' location
/// \p AtLoc.
///
/// Nothing is done unless every property has an ObjC retainable type, at
/// least one property type is written with an attribute at its outermost
/// type location, and the attributed properties are consistently weak or
/// consistently strong.
///
/// If the declaration context of the first property has an ObjC
/// implementation, a weak group is only noted in MigrationContext::AtPropsWeak.
/// Otherwise the property attribute list is edited: "assign" is rewritten to
/// the chosen ownership attribute, or that attribute is added. In both cases
/// the GC attributes themselves are removed, related diagnostics are cleared,
/// and the removed locations are remembered in RemovedAttrSet.
static void checkAllAtProps(MigrationContext &MigrateCtx,
                            SourceLocation AtLoc,
                            IndivPropsTy &IndProps) {
  if (IndProps.empty())
    return;

  // Every property in the group must have an ObjC retainable type.
  IndivPropsTy::iterator TypeI = IndProps.begin();
  IndivPropsTy::iterator TypeE = IndProps.end();
  for (; TypeI != TypeE; ++TypeI) {
    QualType PropTy = (*TypeI)->getType();
    if (PropTy.isNull())
      return;
    if (!PropTy->isObjCRetainableType())
      return;
  }

  typedef std::pair<AttributedTypeLoc, ObjCPropertyDecl *> AttrAndPropTy;
  SmallVector<AttrAndPropTy, 4> ATLs;
  bool SawWeak = false;
  bool SawStrong = false;
  // Holds the as-written attributes of the last property visited below.
  ObjCPropertyDecl::PropertyAttributeKind
    WrittenAttrs = ObjCPropertyDecl::OBJC_PR_noattr;

  IndivPropsTy::iterator PropI = IndProps.begin();
  IndivPropsTy::iterator PropE = IndProps.end();
  for (; PropI != PropE; ++PropI) {
    ObjCPropertyDecl *PD = *PropI;
    WrittenAttrs = PD->getPropertyAttributesAsWritten();

    TypeSourceInfo *TInfo = PD->getTypeSourceInfo();
    if (!TInfo)
      return;

    TypeLoc TL = TInfo->getTypeLoc();
    AttributedTypeLoc *ATL = dyn_cast<AttributedTypeLoc>(&TL);
    if (!ATL)
      continue;
    ATLs.push_back(AttrAndPropTy(*ATL, PD));

    bool IsWeak =
      TInfo->getType().getObjCLifetime() == Qualifiers::OCL_Weak;
    bool IsStrong =
      TInfo->getType().getObjCLifetime() == Qualifiers::OCL_Strong;
    // Any other lifetime on an attributed property aborts the whole group.
    if (!IsWeak && !IsStrong)
      return;
    if (IsWeak)
      SawWeak = true;
    else
      SawStrong = true;
  }

  if (ATLs.empty())
    return;
  // A group mixing weak and strong is left untouched.
  if (SawWeak && SawStrong)
    return;

  TransformActions &TA = MigrateCtx.Pass.TA;
  Transaction Trans(TA);

  Decl *ContextDecl = cast<Decl>(IndProps.front()->getDeclContext());
  if (!GCAttrsCollector::hasObjCImpl(ContextDecl)) {
    StringRef toAttr = "strong";
    if (SawWeak) {
      bool WeakOk = canApplyWeak(MigrateCtx.Pass.Ctx,
                                 IndProps.front()->getType(),
                                 /*AllowOnUnkwownClass=*/true);
      toAttr = StringRef(WeakOk ? "weak" : "unsafe_unretained");
    }

    if (WrittenAttrs & ObjCPropertyDecl::OBJC_PR_assign)
      MigrateCtx.rewritePropertyAttribute("assign", toAttr, AtLoc);
    else
      MigrateCtx.addPropertyAttribute(toAttr, AtLoc);
  } else if (SawWeak) {
    MigrateCtx.AtPropsWeak.insert(AtLoc.getRawEncoding());
  }

  // Drop the GC attributes and the diagnostics that relate to them.
  SourceManager &SM = MigrateCtx.Pass.Ctx.getSourceManager();
  unsigned NumATLs = ATLs.size();
  for (unsigned Idx = 0; Idx != NumATLs; ++Idx) {
    AttrAndPropTy &Entry = ATLs[Idx];

    SourceLocation Loc = Entry.first.getAttrNameLoc();
    if (Loc.isMacroID())
      Loc = SM.getImmediateExpansionRange(Loc).first;

    TA.remove(Loc);
    TA.clearDiagnostic(diag::err_objc_property_attr_mutually_exclusive, AtLoc);
    TA.clearDiagnostic(diag::err_arc_inconsistent_property_ownership,
                       Entry.second->getLocation());
    MigrateCtx.RemovedAttrSet.insert(Loc.getRawEncoding());
  }
}

/// Groups the properties in \p AllProps that were written with 'assign' or
/// 'readonly' by the raw encoding of their '@' location (properties with an
/// invalid '@' location are skipped), then runs checkAllAtProps() on each
/// group.
static void checkAllProps(MigrationContext &MigrateCtx,
                          std::vector<ObjCPropertyDecl *> &AllProps) {
  typedef llvm::TinyPtrVector<ObjCPropertyDecl *> PropGroupTy;
  typedef llvm::DenseMap<unsigned, PropGroupTy> AtPropsMapTy;
  AtPropsMapTy AtProps;

  const unsigned Wanted = ObjCPropertyDecl::OBJC_PR_assign |
                          ObjCPropertyDecl::OBJC_PR_readonly;

  unsigned NumProps = AllProps.size();
  for (unsigned Idx = 0; Idx != NumProps; ++Idx) {
    ObjCPropertyDecl *PD = AllProps[Idx];
    if (!(PD->getPropertyAttributesAsWritten() & Wanted))
      continue;

    SourceLocation AtLoc = PD->getAtLoc();
    if (AtLoc.isInvalid())
      continue;

    AtProps[AtLoc.getRawEncoding()].push_back(PD);
  }

  AtPropsMapTy::iterator GroupI = AtProps.begin();
  AtPropsMapTy::iterator GroupE = AtProps.end();
  for (; GroupI != GroupE; ++GroupI) {
    // The map key is the raw encoding of the group's '@' location.
    SourceLocation AtLoc = SourceLocation::getFromRawEncoding(GroupI->first);
    checkAllAtProps(MigrateCtx, AtLoc, GroupI->second);
  }
}

/// Collects the GC attributes and properties of the translation unit, then
/// runs the three checks over what was collected, in this order: errors for
/// attributes on non-ObjC types, property handling, weak attribute handling.
void GCAttrsTraverser::traverseTU(MigrationContext &MigrateCtx) {
  std::vector<ObjCPropertyDecl *> AllProps;

  GCAttrsCollector Collector(MigrateCtx, AllProps);
  Collector.TraverseDecl(MigrateCtx.Pass.Ctx.getTranslationUnitDecl());

  errorForGCAttrsOnNonObjC(MigrateCtx);
  checkAllProps(MigrateCtx, AllProps);
  checkWeakGCAttrs(MigrateCtx);
}

void MigrationContext::dumpGCAttrs() {
  llvm::errs() << "\n################\n";
  for (unsigned i = 0, e = GCAttrs.size(); i != e; ++i) {
    GCAttrOccurrence &Attr = GCAttrs[i];
    llvm::errs() << "KIND: "
        << (Attr.Kind == GCAttrOccurrence::Strong ? "strong" : "weak");
    llvm::errs() << "\nLOC: ";
    Attr.Loc.dump(Pass.Ctx.getSourceManager());
    llvm::errs() << "\nTYPE: ";
    Attr.ModifiedType.dump();
    if (Attr.Dcl) {
      llvm::errs() << "DECL:\n";
      Attr.Dcl->dump();
    } else {
      llvm::errs() << "DECL: NONE";
    }
    llvm::errs() << "\nMIGRATABLE: " << Attr.FullyMigratable;
    llvm::errs() << "\n----------------\n";
  }
  llvm::errs() << "\n################\n";
}